remove global vars checkpointing
Samuel Weinbach committed Apr 30, 2021
1 parent a230867 commit 1db0be0
Showing 1 changed file with 21 additions and 27 deletions.
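The change applies one pattern throughout the file: instead of each function fetching a module-level argument object via get_args(), callers now pass a neox_args object explicitly. A minimal sketch of the before/after shape, with hypothetical values; only get_args and the neox_args name come from the diff below:

from types import SimpleNamespace

# Hypothetical stand-in for the parsed argument namespace.
_GLOBAL_ARGS = SimpleNamespace(save='checkpoints')

def get_args():
    # Before: functions pulled a hidden module-level object.
    return _GLOBAL_ARGS

def save_dir_before():
    args = get_args()  # implicit dependency, hard to test in isolation
    return args.save

def save_dir_after(neox_args):
    # After: the same data arrives as an explicit parameter.
    return neox_args.save

assert save_dir_before() == save_dir_after(_GLOBAL_ARGS) == 'checkpoints'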
megatron/checkpointing.py: 48 changes (21 additions, 27 deletions)
@@ -28,23 +28,19 @@
 import torch
 from glob import glob

-from megatron import mpu, get_args
+from megatron import get_args
+from megatron import mpu
 from megatron import print_rank_0
 from megatron.utils import natural_sort


-def check_checkpoint_args(checkpoint_args):
+def check_checkpoint_args(neox_args, checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retreived frm checkpoint."""
-    args = get_args() # TODO remove_global_vars

     def _compare(arg_name):
         checkpoint_value = getattr(checkpoint_args, arg_name)
-        args_value = getattr(args, arg_name)
-        error_message = '{} value from checkpoint ({}) is not equal to the ' \
-                        'input argument value ({}).'.format(
-                            arg_name, checkpoint_value, args_value)
+        args_value = getattr(neox_args, arg_name)
+        error_message = '{} value from checkpoint ({}) is not equal to the input argument value ({}).'.format(arg_name, checkpoint_value, args_value)
         assert checkpoint_value == args_value, error_message

     _compare('num_layers')
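check_checkpoint_args now receives the run configuration it validates instead of fetching it globally. A hedged sketch of the comparison its _compare closure performs, with toy values that are not from the commit:

from types import SimpleNamespace

neox_args = SimpleNamespace(num_layers=24)
checkpoint_args = SimpleNamespace(num_layers=24)

arg_name = 'num_layers'
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(neox_args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the input argument value ({}).'.format(
    arg_name, checkpoint_value, args_value)
# Matching values pass silently; a mismatch (say, 12 vs 24) raises
# AssertionError with the message built above.
assert checkpoint_value == args_value, error_message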
@@ -102,51 +98,49 @@ def delete_old_checkpoints(save_dir, n_to_keep):
     pass


-def save_ds_checkpoint(iteration, model, args):
+def save_ds_checkpoint(iteration, model, neox_args):
     """Save a model checkpoint."""
     sd = {'iteration': iteration}
     # rng states.
-    if not args.no_save_rng:
+    if not neox_args.no_save_rng:
         sd['random_rng_state'] = random.getstate()
         sd['np_rng_state'] = np.random.get_state()
         sd['torch_rng_state'] = torch.get_rng_state()
         sd['cuda_rng_state'] = torch.cuda.get_rng_state()
         sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
-    model.save_checkpoint(args.save, client_state=sd)
+    model.save_checkpoint(neox_args.save, client_state=sd)


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
-    args = get_args() # TODO remove_global_vars

-    if args.deepspeed:
-        save_ds_checkpoint(iteration, model, args)
+    if neox_args.deepspeed:
+        save_ds_checkpoint(iteration, model, neox_args)
     else:
         raise ValueError('Must be using deepspeed to use neox')

     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
-        tracker_filename = get_checkpoint_tracker_filename(args.save)
+        tracker_filename = get_checkpoint_tracker_filename(neox_args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))

     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
-    if args.keep_last_n_checkpoints is not None:
-        delete_old_checkpoints(args.save, args.keep_last_n_checkpoints)
+    if neox_args.keep_last_n_checkpoints is not None:
+        delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints)

     # Wait so everyone is done (not necessary)
     torch.distributed.barrier()


-def load_checkpoint(model, optimizer, lr_scheduler):
+def load_checkpoint(model, optimizer, lr_scheduler, neox_args):
     """Load a model checkpoint and return the iteration."""
-    args = get_args() # TODO remove_global_vars

     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(neox_args.load)

     # If no tracker file, return iteration zero.
     if not os.path.isfile(tracker_filename):
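The save and load paths meet at the tracker file: after a checkpoint is written, rank 0 records the latest iteration, and load_checkpoint reads the same file to decide where to resume. A minimal sketch of that round trip, assuming the conventional Megatron tracker filename; the real get_checkpoint_tracker_filename helper is defined elsewhere in this file and is not part of the diff:

import os
import tempfile

# Assumed filename convention, for illustration only.
def get_checkpoint_tracker_filename(save_dir):
    return os.path.join(save_dir, 'latest_checkpointed_iteration.txt')

save_dir = tempfile.mkdtemp()
iteration = 1000

# Save side: rank 0 records the latest completed iteration.
with open(get_checkpoint_tracker_filename(save_dir), 'w') as f:
    f.write(str(iteration))

# Load side: read it back to decide where to resume.
with open(get_checkpoint_tracker_filename(save_dir)) as f:
    assert int(f.read().strip()) == 1000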
@@ -174,9 +168,9 @@ def load_checkpoint(model, optimizer, lr_scheduler):
     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
         tracker_filename)

-    if args.deepspeed:
-        load_optim_and_scheduler = not args.no_load_optim # TODO: These should be configured by separate args
-        checkpoint_name, state_dict = model.load_checkpoint(args.load,
+    if neox_args.deepspeed:
+        load_optim_and_scheduler = not neox_args.no_load_optim # TODO: These should be configured by separate args
+        checkpoint_name, state_dict = model.load_checkpoint(neox_args.load,
                                                             load_optimizer_states=load_optim_and_scheduler,
                                                             load_lr_scheduler_states=load_optim_and_scheduler)

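As the TODO above notes, a single flag currently gates both optimizer and LR-scheduler state loading. A hedged sketch of the split it suggests; no_load_scheduler is a hypothetical name that does not exist in this commit:

from types import SimpleNamespace

def resolve_load_flags(neox_args):
    # Split the single no_load_optim switch into two independent ones;
    # no_load_scheduler is hypothetical, for illustration only.
    return dict(
        load_optimizer_states=not getattr(neox_args, 'no_load_optim', False),
        load_lr_scheduler_states=not getattr(neox_args, 'no_load_scheduler', False),
    )

flags = resolve_load_flags(SimpleNamespace(no_load_optim=True))
assert flags == {'load_optimizer_states': False, 'load_lr_scheduler_states': True}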
@@ -188,7 +182,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
         raise ValueError('Must be using deepspeed to use neox')

     # Set iteration.
-    if args.finetune or release:
+    if neox_args.finetune or release:
         iteration = 0
     else:
         try:
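This hunk feeds the resume logic: finetuning runs and "release" checkpoints restart the iteration counter at zero, while a normal resume recovers the step count from the checkpoint. A simplified sketch of that decision; the real code also guards against missing keys:

def resolve_iteration(state_dict, finetune, release):
    # Fresh fine-tunes and release checkpoints restart the counter.
    if finetune or release:
        return 0
    # Normal resume: recover the step count saved in client state.
    return state_dict['iteration']

assert resolve_iteration({'iteration': 5000}, finetune=False, release=False) == 5000
assert resolve_iteration({'iteration': 5000}, finetune=True, release=False) == 0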
@@ -205,12 +199,12 @@ def load_checkpoint(model, optimizer, lr_scheduler):
     # Check arguments.
     if 'args' in state_dict:
         checkpoint_args = state_dict['args']
-        check_checkpoint_args(checkpoint_args)
+        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')

     # rng states.
-    if not release and not args.finetune and not args.no_load_rng:
+    if not release and not neox_args.finetune and not neox_args.no_load_rng:
         try:
             random.setstate(state_dict['random_rng_state'])
             np.random.set_state(state_dict['np_rng_state'])
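The restore block mirrors save_ds_checkpoint: the RNG states captured at save time are reinstated so data ordering and dropout masks continue exactly where they left off. A minimal sketch of that symmetry for the CPU-side states; the CUDA and mpu tracker states follow the same pattern:

import random
import numpy as np
import torch

# Capture, as save_ds_checkpoint does for these states.
sd = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
}
first_draw = random.random()

# Restore, as load_checkpoint does, then reproduce the identical draw.
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
assert random.random() == first_draw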
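Since the signatures changed, call sites elsewhere in the codebase must thread neox_args through as well. A hypothetical caller under the new signatures; model, optimizer, lr_scheduler, and neox_args stand for the surrounding training objects, which this diff does not touch:

# Hypothetical call sites under the new signatures:
#
#   save_checkpoint(neox_args=neox_args, iteration=iteration, model=model,
#                   optimizer=optimizer, lr_scheduler=lr_scheduler)
#
#   iteration = load_checkpoint(model=model, optimizer=optimizer,
#                               lr_scheduler=lr_scheduler, neox_args=neox_args)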
