Checkpoint validation (#324)
* remove unused tracker file in checkpoint

* fix checkpoint args validation - actually save args in checkpoint

* checkpoint validation with forward pass

* don't check pipeline parallel size when loading checkpoint

* Revert "don't check pipeline parell size when loading checkpoint"

This reverts commit fda18fc.

* Revert "checkpoint validation with forward pass"

This reverts commit 6f55a7f.

* don't check pipe_parallel_size when loading checkpoint

* redo checkpoint validation with forward pass

* checkpoint validation with forward pass for text gen

* simplify checkpoint validation with arange context tokens (see the sketch below)

Co-authored-by: Samuel Weinbach <[email protected]>
sweinbach and Samuel Weinbach committed May 13, 2021
1 parent bb8222f commit d313cde
Showing 5 changed files with 101 additions and 80 deletions.
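
The validation flow these commits add, in outline: at save time the model runs one deterministic forward pass over arange context tokens and the resulting logits are stored in the checkpoint's client state; at load time the same pass is repeated and compared against the stored logits. A minimal sketch of the idea, assuming a toy single-device model that maps a token tensor straight to logits (the real code in megatron/checkpointing.py below goes through NeoX args, get_batch, and DeepSpeed client state):

    import torch

    SEQ_LEN = 2048  # the diff hardcodes 2049 = SEQ_LEN + 1 arange tokens

    def reference_logits(model, batch_size):
        # Deterministic forward pass on arange context tokens.
        was_training = model.training
        model.eval()
        tokens = torch.arange(SEQ_LEN).repeat(batch_size, 1)
        with torch.no_grad():
            logits = model(tokens)
        if was_training:
            model.train()
        # Keep only the first batch item; every row sees identical input.
        return logits.detach().cpu()[0]

    # At save time: stash the reference logits next to the model state.
    #   checkpoint["checkpoint_validation_logits"] = reference_logits(model, bs)
    # At load time: recompute and require a close match.
    #   assert torch.isclose(reference_logits(model, bs),
    #                        checkpoint["checkpoint_validation_logits"]).all()
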
3 changes: 3 additions & 0 deletions configs/local_setup.yml
@@ -17,8 +17,11 @@

"vocab-file": "data/gpt2-vocab.json",
"merge-file": "data/gpt2-merges.txt",

"save": "checkpoints",
"load": "checkpoints",
"checkpoint_validation_with_forward_pass": False,

"tensorboard-dir": "tensorboard",
"log-dir": "logs",
"use_wandb": True,
169 changes: 91 additions & 78 deletions megatron/checkpointing.py
@@ -31,35 +31,70 @@
from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort

from megatron.text_generation_utils import get_batch, forward_model

def check_checkpoint_args(neox_args, checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the ones retrieved from the checkpoint."""

def _compare(arg_name):
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(neox_args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the input argument value ({}).'.format(arg_name, checkpoint_value, args_value)
assert checkpoint_value == args_value, error_message
assert isinstance(checkpoint_args, dict), "args stored in checkpoint must be a dict"
for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
args_value = getattr(neox_args, checkpoint_arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the currently set argument value ({}).'.format(checkpoint_arg_name, checkpoint_arg_value, args_value)
assert checkpoint_arg_value == args_value, error_message

def do_forward_pass(neox_args, model, inference=False):

# set to eval mode
model_was_in_train = model.training
model.eval()

# get context tokens
# always forward full batch size
context_tokens_tensor = torch.arange(2049).repeat((neox_args.train_micro_batch_size_per_gpu, 1)).cuda()

# forward
if inference:
tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens_tensor[:, :2048])
model_inputs = (tokens,
position_ids,
attention_mask,
torch.Tensor(),
)
logits, _ = forward_model(neox_args, model, model_inputs)
elif neox_args.is_pipe_parallel:
data_iterator = iter([{"text": context_tokens_tensor}])
_, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
else:
tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens_tensor[:, :2048])
logits = model((tokens, position_ids, attention_mask))

# reset to train mode, if model was in training before
if model_was_in_train:
model.train()

_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('model_parallel_size')
if logits is not None:
logits = logits.detach().cpu()[0] # just return first batch item (they are all equal)

return logits

def check_forward_pass(neox_args, model, checkpoint_logits, inference):
# do forward pass with loaded checkpoint
logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

# check
if logits is not None and checkpoint_logits is not None: # either can be None on non-final pipeline stages
if not (logits == checkpoint_logits).all().item():
if mpu.get_data_parallel_rank() == 0:
print(" > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result")
assert torch.isclose(logits, checkpoint_logits).all().item(), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"
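
Note on the two comparisons above: the exact-equality check only prints a warning, while torch.isclose (default rtol=1e-05, atol=1e-08) is what gates the assertion, so tiny floating-point drift after reload does not abort the load. A small self-contained illustration:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([1.000001, 2.0, 3.0])   # off by ~1e-6, i.e. float noise

    print((a == b).all().item())             # False -> would print the warning above
    print(torch.isclose(a, b).all().item())  # True  -> the assertion still passes
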

def ensure_directory_exists(filename):
"""Build filename's path if it does not already exist."""
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration,
release=False, mp_rank=None):
"""A unified checkpoint name."""
@@ -73,13 +108,6 @@ def get_checkpoint_name(checkpoints_path, iteration,
else mp_rank),
'model_optim_rng.pt')


def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file records the latest checkpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def delete_old_checkpoints(save_dir, n_to_keep):
if torch.distributed.get_rank() == 0:
ckpt_dir_regex = r'global_step[\d]*'
@@ -97,20 +125,35 @@ def delete_old_checkpoints(save_dir, n_to_keep):
except FileNotFoundError:
pass


def save_ds_checkpoint(iteration, model, neox_args):
"""Save a model checkpoint."""
sd = {'iteration': iteration}
sd = {
'iteration': iteration,
'args': {
'num_layers': neox_args.num_layers,
'hidden_size': neox_args.hidden_size,
'num_attention_heads': neox_args.num_attention_heads,
'max_position_embeddings': neox_args.max_position_embeddings,
'make_vocab_size_divisible_by': neox_args.make_vocab_size_divisible_by,
'padded_vocab_size': neox_args.padded_vocab_size,
'tokenizer_type': neox_args.tokenizer_type,
'model_parallel_size': neox_args.model_parallel_size
}
}
# rng states.
if not neox_args.no_save_rng:
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

if neox_args.checkpoint_validation_with_forward_pass:
logits = do_forward_pass(neox_args=neox_args, model=model)
sd['checkpoint_validation_logits'] = logits

model.save_checkpoint(neox_args.save, client_state=sd)


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint."""

@@ -119,14 +162,6 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
else:
raise ValueError('Must be using deepspeed to use neox')

# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(neox_args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))

# Wait so everyone is done (necessary)
torch.distributed.barrier()
if neox_args.keep_last_n_checkpoints is not None:
@@ -135,39 +170,9 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
# Wait so everyone is done (not necessary)
torch.distributed.barrier()


def load_checkpoint(neox_args, model, optimizer, lr_scheduler):
def load_checkpoint(neox_args, model, optimizer, lr_scheduler, inference=False):
"""Load a model checkpoint and return the iteration."""

# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(neox_args.load)

# If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0

# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
sys.exit()

assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)

if neox_args.deepspeed:
load_optim_and_scheduler = not neox_args.no_load_optim # TODO: These should be configured by separate args
checkpoint_name, state_dict = model.load_checkpoint(neox_args.load,
@@ -177,34 +182,42 @@ def load_checkpoint(neox_args, model, optimizer, lr_scheduler):
if checkpoint_name is None:
if mpu.get_data_parallel_rank() == 0:
print("Unable to load checkpoint.")
return iteration
return 0 # iteration 0, if no checkpoint was loaded
else:
raise ValueError('Must be using deepspeed to use neox')

# Set iteration.
if neox_args.finetune or release:
if neox_args.finetune:
iteration = 0
else:
try:
iteration = state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
'iteration from checkpoint {}, exiting'.format(
checkpoint_name))
sys.exit()
iteration = state_dict.get('iteration') or state_dict.get('total_iters') # 'total_iters' kept for backward compatibility with older checkpoints
if iteration is None:
raise ValueError('Unable to load iteration from checkpoint {}, exiting'.format(checkpoint_name))

# Check arguments.
if 'args' in state_dict:
checkpoint_args = state_dict['args']
check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
print_rank_0(' > validated currently set args with arguments in the checkpoint ...')
else:
print_rank_0('could not find arguments in the checkpoint ...')
print_rank_0(' > could not find arguments in the checkpoint for validation...')

# Check loaded checkpoint with forward pass
if neox_args.checkpoint_validation_with_forward_pass:
if "checkpoint_validation_logits" in state_dict:
check_forward_pass(
neox_args=neox_args,
model=model,
checkpoint_logits=state_dict["checkpoint_validation_logits"],
inference=inference
)
print_rank_0(' > validated loaded checkpoint with forward pass ...')
else:
if mpu.get_data_parallel_rank() == 0:
print(' > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}'.format(checkpoint_name))

# rng states.
if not release and not neox_args.finetune and not neox_args.no_load_rng:
if not neox_args.finetune and not neox_args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
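
Because the args dict and the validation logits travel in DeepSpeed's client state, a saved checkpoint can be inspected offline. A rough sketch; the global_step directory and the mp_rank_00_model_states.pt filename follow the usual DeepSpeed layout and are assumptions here, not something this diff defines:

    import torch

    # Hypothetical path: point this at a real global_step directory under "save".
    state = torch.load("checkpoints/global_step1000/mp_rank_00_model_states.pt",
                       map_location="cpu")

    print(state["iteration"])    # training iteration at save time
    print(state["args"])         # fixed model args written by save_ds_checkpoint
    logits = state.get("checkpoint_validation_logits")
    if logits is not None:
        print(logits.shape)      # reference logits for the forward-pass check
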
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -610,6 +610,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Directory containing a model checkpoint.
"""

checkpoint_validation_with_forward_pass: bool = False
"""
Save the output of a deterministic forward pass with the checkpoint and validate it against a fresh forward pass after load.
"""

save_interval: int = None
"""
Number of iterations between checkpoint saves.
2 changes: 1 addition & 1 deletion megatron/training.py
@@ -349,7 +349,7 @@ def setup_model_and_optimizer(neox_args, inference=False, get_key_value=True):

if neox_args.load is not None:
neox_args.iteration = load_checkpoint(neox_args=neox_args, model=model, optimizer=optimizer,
lr_scheduler=lr_scheduler)
lr_scheduler=lr_scheduler, inference=inference)
print_rank_0(f'Loading checkpoint and starting from iteration {neox_args.iteration}')
else:
neox_args.iteration = 0
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -2,7 +2,7 @@ pybind11==2.6.2
six
regex
numpy==1.20.2
-e git+git://github.com/EleutherAI/DeeperSpeed.git@75de46ca9edb98e1503f74c5e5aa84e0b0c8a05d#egg=deepspeed
-e git+git://github.com/EleutherAI/DeeperSpeed.git@3389e4f525445390aa141ef7c4db5376b13684f7#egg=deepspeed
mpi4py==3.0.3
wandb==0.10.28
einops==0.3.0
