Commit 871e679: fix eval_helper
Author: sdtblck
Committed: Apr 28, 2021
Parent: c80212e
Showing 3 changed files with 7 additions and 11 deletions.
megatron/checkpointing.py (12 changes: 4 additions & 8 deletions)
@@ -26,7 +26,6 @@
 import numpy as np
 
 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP
 from glob import glob
 
 from megatron import mpu, get_args
@@ -145,17 +144,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     torch.distributed.barrier()
 
 
-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
+def load_checkpoint(model, optimizer, lr_scheduler):
     """Load a model checkpoint and return the iteration."""
     args = get_args()
-    load_dir = getattr(args, load_arg)
 
-    if isinstance(model, torchDDP):
-        model = model.module
     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(load_dir)
+    tracker_filename = get_checkpoint_tracker_filename(args.load)
 
-    # If no tracker file, return iretation zero.
+    # If no tracker file, return iteration zero.
     if not os.path.isfile(tracker_filename):
         print_rank_0('WARNING: could not find the metadata file {} '.format(
             tracker_filename))
@@ -183,7 +179,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
 
     if args.deepspeed:
         load_optim_and_scheduler = not args.no_load_optim  # TODO: These should be configured by separate args
-        checkpoint_name, state_dict = model.load_checkpoint(load_dir,
+        checkpoint_name, state_dict = model.load_checkpoint(args.load,
                                                              load_optimizer_states=load_optim_and_scheduler,
                                                              load_lr_scheduler_states=load_optim_and_scheduler)
 
megatron/neox_arguments/arguments.py (2 changes: 1 addition & 1 deletion)
@@ -553,7 +553,7 @@ def calculate_derived(self):
         # Update 'is pipe parallel' flag
         # if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with
         # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
-        self.update_value("is_pipe_parallel", self.pipe_parallel_size > 1)
+        self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)
 
    ############################################################################################################################
    # start of validation functions
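For context, the practical effect of moving from "> 1" to ">= 1" is that only pipe_parallel_size == 0 now leaves is_pipe_parallel unset, so a run with a single pipeline stage presumably stays on the PipelineModule code paths that the other two hunks in this commit touch. A standalone sketch of the flag logic (illustration only, not repository code):

```python
# Standalone sketch (illustration only, not repository code): how the old and
# new comparisons classify pipe_parallel_size. Only ">= 1" treats a single
# pipeline stage as pipeline-parallel.
for pipe_parallel_size in (0, 1, 2):
    old_flag = pipe_parallel_size > 1    # before this commit
    new_flag = pipe_parallel_size >= 1   # after this commit
    print(pipe_parallel_size, old_flag, new_flag)
# 0 False False
# 1 False True
# 2 True True
```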
megatron/training.py (4 changes: 2 additions & 2 deletions)
@@ -242,7 +242,7 @@ def setup_model_and_optimizer(model_provider_func):
             dist_init_required=False,
             model_parameters=_model_params,
             config_params=args.deepspeed_config,
-            mpu=mpu if args.is_pipe_parallel else None
+            mpu=mpu if not args.is_pipe_parallel else None
         )
         model.total_params = get_total_params(model.module)
         print_rank_0(f' > total params: {"{:,}".format(model.total_params)}')
@@ -601,7 +601,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
     # Pipeline parallelism needs eval_batch() instead of a simple forward().
     args = get_args()
     if args.is_pipe_parallel:
-        def _eval_helper(data_iter):
+        def _eval_helper(data_iter, _):
            loss = model.eval_batch(data_iter)
            return None, {'lm loss': loss}
        forward_step_func = _eval_helper
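This last hunk is what the commit title refers to: under pipeline parallelism the evaluation loop apparently calls forward_step_func with two arguments, the data iterator and the model, so the previous one-argument _eval_helper would raise a TypeError. A self-contained sketch of the adapter, using stand-in names for the DeepSpeed engine and the evaluation loop (assumptions, not repository code):

```python
# Standalone sketch (assumptions, not repository code) of why _eval_helper
# needs to accept and ignore a second argument.

class FakePipelineEngine:
    """Stand-in for a DeepSpeed pipeline engine exposing eval_batch()."""
    def eval_batch(self, data_iter):
        return 0.0  # pretend loss

def evaluate_loop(forward_step_func, data_iterator, model):
    """Stand-in for the evaluation loop, which passes both the iterator and the model."""
    return forward_step_func(data_iterator, model)

model = FakePipelineEngine()

def _eval_helper(data_iter, _):  # fixed signature: the second (model) argument is ignored
    loss = model.eval_batch(data_iter)
    return None, {'lm loss': loss}

print(evaluate_loop(_eval_helper, iter([]), model))  # -> (None, {'lm loss': 0.0})
```

The underscore parameter simply absorbs the model argument that eval_batch() does not need.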
