Improve Eval Harness (#471)
* add dp

* fix all_gather

* add ability to specify iteration @ eval time

* fix

* debug

* fix

* fix

* fix

* fix

* fix

* remove debug

* add results_prefix arg

* fix issues with greedy until + last batch shape
sdtblck committed Dec 20, 2021
1 parent bbbc5fb commit 3ad6195
Showing 7 changed files with 1,162 additions and 554 deletions.
293 changes: 219 additions & 74 deletions eval_tasks/eval_adapter.py

Large diffs are not rendered by default.
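
The eval_tasks/eval_adapter.py diff is not shown here, but the commit messages above ("add dp", "fix all_gather") indicate the harness now evaluates with data parallelism and merges per-rank results. The snippet below is a minimal, generic sketch of that pattern using torch.distributed; the function name gather_data_parallel_results and its arguments are illustrative assumptions, not code from this commit.

# Illustrative sketch only -- not the actual eval_tasks/eval_adapter.py change.
# Assumes torch.distributed is already initialized (e.g. by the DeepSpeed/NeoX launcher).
import torch.distributed as dist


def gather_data_parallel_results(local_results: dict) -> dict:
    """Merge eval results computed on each data-parallel rank.

    `local_results` maps the request indices handled by this rank to their
    scores; after gathering, every rank holds the complete mapping.
    """
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # all_gather_object collects an arbitrary picklable object from every rank.
    dist.all_gather_object(gathered, local_results)

    merged = {}
    for rank_results in gathered:
        merged.update(rank_results)
    return merged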

226 changes: 147 additions & 79 deletions megatron/checkpointing.py
@@ -32,6 +32,9 @@
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model
from pathlib import Path
from pprint import pformat


def check_checkpoint_args(neox_args, checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
@@ -40,81 +43,110 @@ def check_checkpoint_args(neox_args, checkpoint_args):
assert isinstance(checkpoint_args, dict), "args stored in checkpoint is a dict"
for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
args_value = getattr(neox_args, checkpoint_arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the currently set argument value ({}).'.format(checkpoint_arg_name, checkpoint_arg_value, args_value)
error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
checkpoint_arg_name, checkpoint_arg_value, args_value
)
assert checkpoint_arg_value == args_value, error_message


def do_forward_pass(neox_args, model, inference=False):

# set to eval mode
model_was_in_train = model.training
model.eval()

# get context tokens
# always forward full batch size
context_tokens_tensor = torch.arange(2049).repeat((neox_args.train_micro_batch_size_per_gpu, 1)).cuda()
context_tokens_tensor = (
torch.arange(2049).repeat((neox_args.train_micro_batch_size_per_gpu, 1)).cuda()
)

# forward
if inference:
tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens_tensor[:, :2048])
model_inputs = (tokens,
position_ids,
attention_mask,
torch.Tensor(),
)
tokens, attention_mask, position_ids = get_batch(
neox_args, context_tokens_tensor[:, :2048]
)
model_inputs = (
tokens,
position_ids,
attention_mask,
torch.Tensor(),
)
logits, _ = forward_model(neox_args, model, model_inputs)
elif neox_args.is_pipe_parallel:
data_iterator = iter([{"text": context_tokens_tensor}])
_, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
else:
tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens_tensor[:, :2048])
tokens, attention_mask, position_ids = get_batch(
neox_args, context_tokens_tensor[:, :2048]
)
logits = model((tokens, position_ids, attention_mask))

# reset to train mode, if model was in training before
if model_was_in_train:
model.train()

if logits is not None:
logits = logits.detach().cpu()[0] # just return first batch item (they are all equal)
logits = logits.detach().cpu()[
0
] # just return first batch item (they are all equal)

return logits


def check_forward_pass(neox_args, model, checkpoint_logits, inference):
# do forward pass with loaded checkpoint
logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

# check
if logits is not None and checkpoint_logits is not None: # this could be the case for non-final pipeline stages
if (
logits is not None and checkpoint_logits is not None
): # this could be the case for non-final pipeline stages
if not (logits == checkpoint_logits).all().item():
if mpu.get_data_parallel_rank() == 0:
print(" > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result")
assert torch.isclose(logits, checkpoint_logits).all().item(), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"
print(
" > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result"
)
assert (
torch.isclose(logits, checkpoint_logits).all().item()
), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"


def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)

def get_checkpoint_name(checkpoints_path, iteration,
release=False, mp_rank=None):

def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
"""A unified checkpoint name."""
if release:
directory = 'release'
directory = "release"
else:
directory = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None
else mp_rank),
'model_optim_rng.pt')
directory = "iter_{:07d}".format(iteration)
return os.path.join(
checkpoints_path,
directory,
"mp_rank_{:02d}".format(
mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
),
"model_optim_rng.pt",
)


def delete_old_checkpoints(save_dir, n_to_keep):
if torch.distributed.get_rank() == 0:
ckpt_dir_regex = r'global_step[\d]*'
if save_dir.endswith('/'):
save_dir = save_dir.strip('/')
all_ckpts = natural_sort([i for i in glob(f'{save_dir}/*') if os.path.isdir(i)
and re.search(ckpt_dir_regex, i)])
ckpt_dir_regex = r"global_step[\d]*"
if save_dir.endswith("/"):
save_dir = save_dir.strip("/")
all_ckpts = natural_sort(
[
i
for i in glob(f"{save_dir}/*")
if os.path.isdir(i) and re.search(ckpt_dir_regex, i)
]
)
n_to_delete = len(all_ckpts) - n_to_keep
if n_to_delete > 0:
to_delete = all_ckpts[:n_to_delete]
@@ -125,29 +157,30 @@ def delete_old_checkpoints(save_dir, n_to_keep):
except FileNotFoundError:
pass


def save_ds_checkpoint(iteration, model, neox_args):
"""Save a model checkpoint."""
sd = {
'iteration': iteration,
'args': {
'num_layers': neox_args.num_layers,
'hidden_size': neox_args.hidden_size,
'num_attention_heads': neox_args.num_attention_heads,
'max_position_embeddings': neox_args.max_position_embeddings,
'make_vocab_size_divisible_by': neox_args.make_vocab_size_divisible_by,
'padded_vocab_size': neox_args.padded_vocab_size,
'tokenizer_type': neox_args.tokenizer_type,
'model_parallel_size': neox_args.model_parallel_size
}
}
"iteration": iteration,
"args": {
"num_layers": neox_args.num_layers,
"hidden_size": neox_args.hidden_size,
"num_attention_heads": neox_args.num_attention_heads,
"max_position_embeddings": neox_args.max_position_embeddings,
"make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
"padded_vocab_size": neox_args.padded_vocab_size,
"tokenizer_type": neox_args.tokenizer_type,
"model_parallel_size": neox_args.model_parallel_size,
},
}
# rng states.
if not neox_args.no_save_rng:
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
sd["random_rng_state"] = random.getstate()
sd["np_rng_state"] = np.random.get_state()
sd["torch_rng_state"] = torch.get_rng_state()
sd["cuda_rng_state"] = torch.cuda.get_rng_state()
sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

if neox_args.checkpoint_validation_with_forward_pass:
logits = do_forward_pass(neox_args=neox_args, model=model)
sd['checkpoint_validation_logits'] = logits
@@ -172,7 +205,7 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
if neox_args.deepspeed:
save_ds_checkpoint(iteration, model, neox_args)
else:
raise ValueError('Must be using deepspeed to use neox')
raise ValueError("Must be using deepspeed to use neox")

# Wait so everyone is done (necessary)
torch.distributed.barrier()
@@ -182,72 +215,107 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
# Wait so everyone is done (not necessary)
torch.distributed.barrier()

def load_checkpoint(neox_args, model, optimizer, lr_scheduler, inference=False):
"""Load a model checkpoint and return the iteration."""

def load_checkpoint(
neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
"""Load a model checkpoint and return the iteration."""
if neox_args.deepspeed:
load_optim_and_scheduler = not neox_args.no_load_optim # TODO: These should be configured by separate args
load_optim_and_scheduler = (
not neox_args.no_load_optim
) # TODO: These should be configured by separate args
if neox_args.finetune:
load_optim_and_scheduler = False
checkpoint_name, state_dict = model.load_checkpoint(neox_args.load,
load_optimizer_states=load_optim_and_scheduler,
load_lr_scheduler_states=load_optim_and_scheduler)
if iteration is not None:
tag = f"global_step{iteration}"
else:
tag = None
checkpoint_name, state_dict = model.load_checkpoint(
neox_args.load,
load_optimizer_states=load_optim_and_scheduler,
load_lr_scheduler_states=load_optim_and_scheduler,
tag=tag,
)

if checkpoint_name is None:
# if an iteration is specified, we want to raise an error here rather than
# continuing silently, since we are trying to load a specific checkpoint
if iteration is not None:
available_checkpoints = sorted(
[
int(i.name.replace("global_step", ""))
for i in Path(neox_args.load).glob("global_step*")
]
)
raise ValueError(
f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
)
if mpu.get_data_parallel_rank() == 0:
print("Unable to load checkpoint.")
return 0 # iteration 0, if not checkpoint loaded

return 0 # iteration 0, if not checkpoint loaded
else:
raise ValueError('Must be using deepspeed to use neox')
raise ValueError("Must be using deepspeed to use neox")

# Set iteration.
if neox_args.finetune:
iteration = 0
else:
iteration = state_dict.get('iteration') or state_dict.get("total_iters") # total_iters backward compatible with older checkpoints
iteration = state_dict.get("iteration") or state_dict.get(
"total_iters"
) # total_iters backward compatible with older checkpoints
if iteration is None:
raise ValueError(f'Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting')
raise ValueError(
f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
)

# Check arguments.
if 'args' in state_dict:
checkpoint_args = state_dict['args']
if "args" in state_dict:
checkpoint_args = state_dict["args"]
check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
print_rank_0(' > validated currently set args with arguments in the checkpoint ...')
print_rank_0(
" > validated currently set args with arguments in the checkpoint ..."
)
else:
print_rank_0(' > could not find arguments in the checkpoint for validation...')
print_rank_0(" > could not find arguments in the checkpoint for validation...")

# Check loaded checkpoint with forward pass
if neox_args.checkpoint_validation_with_forward_pass:
if "checkpoint_validation_logits" in state_dict:
check_forward_pass(
neox_args=neox_args,
model=model,
neox_args=neox_args,
model=model,
checkpoint_logits=state_dict["checkpoint_validation_logits"],
inference=inference
)
print_rank_0(' > validated loaded checkpoint with forward pass ...')
inference=inference,
)
print_rank_0(" > validated loaded checkpoint with forward pass ...")
else:
if mpu.get_data_parallel_rank() == 0:
print(' > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}'.format(checkpoint_name))
print(
" > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
checkpoint_name
)
)

# rng states.
if not neox_args.finetune and not neox_args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
random.setstate(state_dict["random_rng_state"])
np.random.set_state(state_dict["np_rng_state"])
torch.set_rng_state(state_dict["torch_rng_state"])
torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
print_rank_0(
"Unable to load optimizer from checkpoint {}. "
"Specify --no-load-rng or --finetune to prevent "
"attempting to load the optimizer state, "
"exiting ...".format(checkpoint_name)
)
sys.exit()

torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
print(" successfully loaded {}".format(checkpoint_name))

return iteration
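
Taken together with the new signature above, the iteration argument lets the eval harness pin evaluation to a specific saved step instead of silently taking the latest checkpoint. Below is a hedged usage sketch; neox_args, model, optimizer, and lr_scheduler are assumed to come from the usual NeoX setup code, which is outside this diff, and the iteration number is hypothetical.

# Hedged usage sketch of the new `iteration` parameter shown in this diff.
from megatron.checkpointing import load_checkpoint

# iteration=None (the default) loads the latest checkpoint; an integer loads
# the `global_step{iteration}` directory and now raises ValueError listing the
# available iterations if that directory does not exist.
loaded_iteration = load_checkpoint(
    neox_args=neox_args,      # assumed to exist from NeoX setup
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    inference=True,
    iteration=150000,         # hypothetical checkpoint step
)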