Cleaned up load/save checkpoint printing
mohammad authored and deepakn94 committed Dec 19, 2020
1 parent b81cad6 commit 8a6e56b
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions megatron/checkpointing.py
@@ -110,6 +110,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, torchDDP):
         model = model.module
+
+    if torch.distributed.get_rank() == 0:
+        print('saving checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
+
     if mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
@@ -137,14 +142,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 
         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
-        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-              format(torch.distributed.get_rank(), iteration, checkpoint_name))
         ensure_directory_exists(checkpoint_name)
         torch.save(state_dict, checkpoint_name)
-        print(' successfully saved {}'.format(checkpoint_name))
 
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
@@ -192,9 +197,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
 
     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' loading checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     # Load the checkpoint.
     try:
@@ -276,8 +281,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         sys.exit()
 
     torch.distributed.barrier()
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' successfully loaded checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     return iteration

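For reference, a minimal standalone sketch of the printing pattern this commit adopts: the progress message is printed once from global rank 0 instead of once per rank, and the barrier ensures every rank has finished writing before success is announced. The function name, the demo checkpoint path, and the single-process gloo initialization below are illustrative assumptions, not part of the repository.

import os
import torch
import torch.distributed as dist


def save_with_clean_printing(state_dict, path, iteration):
    # Announce the save once, from global rank 0 only.
    if dist.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, path), flush=True)
    # In Megatron, only data-parallel rank 0 writes its shard; this sketch
    # simply writes a single file.
    torch.save(state_dict, path)
    # Wait so everyone is done before announcing success (or updating a
    # tracker file that other ranks may read).
    dist.barrier()
    if dist.get_rank() == 0:
        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, path), flush=True)


if __name__ == '__main__':
    # Single-process gloo group so the sketch runs without a distributed launcher.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    save_with_clean_printing({'iteration': 100}, 'ckpt_demo.pt', 100)
    dist.destroy_process_group()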
