
Commit

fix checkpointing bug
sdtblck committed Apr 25, 2021
1 parent 1f28e42 commit 71e2129
Showing 3 changed files with 7 additions and 132 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -132,4 +132,8 @@ dmypy.json
wandb/

# data files
data/
data/

# ckpt files
*.pt
*.ckpt
93 changes: 2 additions & 91 deletions megatron/checkpointing.py
@@ -34,21 +34,6 @@
from megatron import print_rank_0
from megatron.utils import natural_sort

_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
global _CHECKPOINT_VERSION
assert _CHECKPOINT_VERSION is None, \
"checkpoint version already set"
_CHECKPOINT_VERSION = value


def get_checkpoint_version():
global _CHECKPOINT_VERSION
return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from the checkpoint."""
@@ -148,42 +133,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if args.deepspeed:
save_ds_checkpoint(iteration, model, args)
else:
# Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP):
model = model.module
if mpu.get_data_parallel_rank() == 0:

# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 2.0
state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint()

# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()

# RNG states.
if not args.no_save_rng:
state_dict['random_rng_state'] = random.getstate()
state_dict['np_rng_state'] = np.random.get_state()
state_dict['torch_rng_state'] = torch.get_rng_state()
state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
state_dict['rng_tracker_states'] \
= mpu.get_cuda_rng_tracker().get_states()

# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
format(torch.distributed.get_rank(), iteration,
checkpoint_name))
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
print(' successfully saved {}'.format(checkpoint_name))
raise ValueError('Must be using deepspeed to use neox')

# Wait so everyone is done (necessary)
torch.distributed.barrier()
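
For context, the non-deepspeed branch removed above assembled a plain torch.save checkpoint; after this commit, reaching it simply raises ValueError('Must be using deepspeed to use neox'). A minimal standalone sketch of the removed pattern (the save_plain_checkpoint name and the absence of model/data parallelism are assumptions for illustration, not part of this repo):

import os
import random
import numpy as np
import torch

def save_plain_checkpoint(path, iteration, model, optimizer=None, lr_scheduler=None):
    # Bundle arguments, model weights, optimizer/scheduler state, and RNG
    # states into one dict, mirroring the removed non-deepspeed path.
    state_dict = {
        'iteration': iteration,
        'model': model.state_dict(),
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
    }
    if optimizer is not None:
        state_dict['optimizer'] = optimizer.state_dict()
    if lr_scheduler is not None:
        state_dict['lr_scheduler'] = lr_scheduler.state_dict()
    if torch.cuda.is_available():
        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    torch.save(state_dict, path)
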
@@ -250,46 +200,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
return iteration

else:
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))

# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
except BaseException:
print_rank_0('could not load the checkpoint')
sys.exit()
# Model.

model.load_state_dict(state_dict['model'])

# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
except KeyError:
print_rank_0(
'Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()

# set checkpoint version
set_checkpoint_version(state_dict.get('checkpoint_version', 0))
raise ValueError('Must be using deepspeed to use neox')

# Set iteration.
if args.finetune or release:
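
The removed load path also carried a backward-compatibility trick: when unpickling raised ModuleNotFoundError, it temporarily aliased the old module path 'fp16.loss_scaler' to 'megatron.fp16.loss_scaler' in sys.modules and retried the load. A generic sketch of that technique (torch_load_with_alias is a hypothetical helper, not part of this repo):

import importlib
import sys
import torch

def torch_load_with_alias(path, old_module, new_module):
    # Alias an old pickled module path to its new location before retrying
    # torch.load, then remove the alias again.
    try:
        return torch.load(path, map_location='cpu')
    except ModuleNotFoundError:
        sys.modules[old_module] = importlib.import_module(new_module)
        try:
            return torch.load(path, map_location='cpu')
        finally:
            sys.modules.pop(old_module, None)

# e.g. torch_load_with_alias(ckpt_path, 'fp16.loss_scaler', 'megatron.fp16.loss_scaler')
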
40 changes: 0 additions & 40 deletions megatron/model/transformer.py
@@ -26,7 +26,6 @@
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu, exists
@@ -256,36 +255,6 @@ def __init__(self, attention_mask_func, init_method,
get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
checkpoint = deepspeed.checkpointing.checkpoint

def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size()
if num_splits_first:
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
           -->(transpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """

intermediate_shape = input_shape[:-1] + \
(num_splits, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)

mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
else:
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
           -->(transpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """

intermediate_shape = input_shape[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, num_splits)

mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape)

return mixed_layer

def forward(self, hidden_states, attention_mask, layer_past=None):

# hidden_states: [sq, b, h]
@@ -297,15 +266,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
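
The deleted _transpose_last_dim helper reordered the fused QKV projection's last dimension from older checkpoint layouts ((num_splits, np, hn) for version 0, (np, hn, num_splits) for version 1.0) into the current (np, num_splits, hn) layout. A standalone sketch with explicit arguments (the free-standing function, its signature, and the np_heads/hn parameter names are assumptions for illustration):

import torch

def transpose_last_dim(mixed_layer, num_splits, num_splits_first, np_heads, hn):
    # Reorder the fused last dim to (np, num_splits, hn) without changing the
    # overall shape, as the removed method did for old checkpoint versions.
    input_shape = mixed_layer.size()
    if num_splits_first:
        # [s, b, num_splits * np * hn] -> [s, b, num_splits, np, hn]
        #                              -> [s, b, np, num_splits, hn] (transpose)
        intermediate = input_shape[:-1] + (num_splits, np_heads, hn)
        mixed_layer = mixed_layer.view(*intermediate).transpose(-2, -3).contiguous()
    else:
        # [s, b, np * hn * num_splits] -> [s, b, np, hn, num_splits]
        #                              -> [s, b, np, num_splits, hn] (transpose)
        intermediate = input_shape[:-1] + (np_heads, hn, num_splits)
        mixed_layer = mixed_layer.view(*intermediate).transpose(-1, -2).contiguous()
    return mixed_layer.view(*input_shape)

# e.g. with s=4, b=2, np_heads=8, hn=16 and a fused QKV tensor (num_splits=3):
# qkv = torch.randn(4, 2, 3 * 8 * 16)
# qkv = transpose_last_dim(qkv, 3, True, 8, 16)   # checkpoint_version == 0 layout
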
