
Commit

remove some unused codepaths
sid committed Apr 22, 2021
1 parent e1f7fcb commit 2c457e3
Showing 4 changed files with 8 additions and 94 deletions.
76 changes: 2 additions & 74 deletions megatron/checkpointing.py
@@ -148,42 +148,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     if args.deepspeed:
         save_ds_checkpoint(iteration, model, args)
     else:
-        # Only rank zero of the data parallel writes to the disk.
-        if isinstance(model, torchDDP):
-            model = model.module
-        if mpu.get_data_parallel_rank() == 0:
-
-            # Arguments, iteration, and model.
-            state_dict = {}
-            state_dict['args'] = args
-            state_dict['checkpoint_version'] = 2.0
-            state_dict['iteration'] = iteration
-            state_dict['model'] = model.state_dict_for_save_checkpoint()
-
-            # Optimizer stuff.
-            if not args.no_save_optim:
-                if optimizer is not None:
-                    state_dict['optimizer'] = optimizer.state_dict()
-                if lr_scheduler is not None:
-                    state_dict['lr_scheduler'] = lr_scheduler.state_dict()
-
-            # RNG states.
-            if not args.no_save_rng:
-                state_dict['random_rng_state'] = random.getstate()
-                state_dict['np_rng_state'] = np.random.get_state()
-                state_dict['torch_rng_state'] = torch.get_rng_state()
-                state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-                state_dict['rng_tracker_states'] \
-                    = mpu.get_cuda_rng_tracker().get_states()
-
-            # Save.
-            checkpoint_name = get_checkpoint_name(args.save, iteration)
-            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-                  format(torch.distributed.get_rank(), iteration,
-                         checkpoint_name))
-            ensure_directory_exists(checkpoint_name)
-            torch.save(state_dict, checkpoint_name)
-            print(' successfully saved {}'.format(checkpoint_name))
+        raise ValueError('Must be using DeepSpeed')
 
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
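
For context on what the deleted branch was doing: it assembled a state_dict by hand (arguments, iteration, model weights, optimizer, LR scheduler, and RNG states) and wrote it with torch.save on data-parallel rank 0. With checkpointing now funnelled exclusively through save_ds_checkpoint, the DeepSpeed engine covers the same ground. A minimal sketch of that pattern, assuming model is a deepspeed.DeepSpeedEngine (this is not the repo's actual save_ds_checkpoint, whose body is outside this diff):

# Sketch only: how DeepSpeed's engine-level checkpointing subsumes the removed
# manual torch.save() path. `model` is assumed to be a deepspeed.DeepSpeedEngine;
# the helper name and tag format are illustrative, not the repo's.
import random

import numpy as np
import torch


def save_deepspeed_checkpoint(iteration, model, args):
    # Extra, non-engine state rides along in client_state and is returned
    # verbatim by load_checkpoint() later.
    client_state = {
        'iteration': iteration,
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
    }
    # DeepSpeed shards and saves the model, optimizer and fp16 master weights itself.
    model.save_checkpoint(args.save, tag='iter_{:07d}'.format(iteration),
                          client_state=client_state)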
@@ -248,45 +213,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
             if mpu.get_data_parallel_rank() == 0:
                 print("Unable to load checkpoint.")
             return iteration
-
     else:
-        # Checkpoint.
-        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-        if mpu.get_data_parallel_rank() == 0:
-            print('global rank {} is loading checkpoint {}'.format(
-                torch.distributed.get_rank(), checkpoint_name))
-
-        # Load the checkpoint.
-        try:
-            state_dict = torch.load(checkpoint_name, map_location='cpu')
-        except ModuleNotFoundError:
-            # For backward compatibility.
-            print_rank_0(' > deserializing using the old code structure ...')
-            sys.modules['fp16.loss_scaler'] = sys.modules[
-                'megatron.fp16.loss_scaler']
-            state_dict = torch.load(checkpoint_name, map_location='cpu')
-            sys.modules.pop('fp16.loss_scaler', None)
-        except BaseException:
-            print_rank_0('could not load the checkpoint')
-            sys.exit()
-        # Model.
-
-        model.load_state_dict(state_dict['model'])
-
-        # Optimizer.
-        if not release and not args.finetune and not args.no_load_optim:
-            try:
-                if optimizer is not None:
-                    optimizer.load_state_dict(state_dict['optimizer'])
-                if lr_scheduler is not None:
-                    lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
-            except KeyError:
-                print_rank_0(
-                    'Unable to load optimizer from checkpoint {}. '
-                    'Specify --no-load-optim or --finetune to prevent '
-                    'attempting to load the optimizer state, '
-                    'exiting ...'.format(checkpoint_name))
-                sys.exit()
+        raise ValueError('Must be using DeepSpeed')
 
     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
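
The mirror image on the load side: the removed branch called torch.load directly and restored the model, optimizer and LR scheduler from the resulting state_dict, with a backward-compatibility shim for the old fp16.loss_scaler module path. Under DeepSpeed, the engine's load_checkpoint does that work and hands back whatever was stored as client_state at save time. A rough sketch of that call, with illustrative names rather than the repo's exact signature:

# Sketch only: DeepSpeed-style loading that replaces the manual torch.load path.
# `model` is assumed to be a deepspeed.DeepSpeedEngine.
def load_deepspeed_checkpoint(model, load_dir):
    load_path, client_state = model.load_checkpoint(
        load_dir,
        load_optimizer_states=True,
        load_lr_scheduler_states=True)
    if load_path is None:
        print("Unable to load checkpoint.")
        return 0
    # Anything passed as client_state at save time (e.g. the iteration and
    # RNG states in the save sketch above) comes back here.
    return client_state.get('iteration', 0)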
7 changes: 4 additions & 3 deletions megatron/data/gpt2_dataset.py
@@ -50,6 +50,7 @@ def print_split_stats(name, index):
         print_rank_0(' document indices in [{}, {}) total of {} '
                      'documents'.format(splits[index], splits[index + 1],
                                         splits[index + 1] - splits[index]))
+
     print_split_stats('train', 0)
     print_split_stats('validation', 1)
     print_split_stats('test', 2)
@@ -110,6 +111,7 @@ def __len__(self):
         # sample i --> [sample_idx[i], sample_idx[i+1])
         return self.sample_idx.shape[0] - 1
 
+
     def __getitem__(self, idx):
         # Get the shuffled index.
         idx = self.shuffle_idx[idx]
@@ -166,9 +168,8 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     # Build the indexed mapping if not exist.
     if torch.distributed.get_rank() == 0:
         if (not os.path.isfile(doc_idx_filename)) or \
-           (not os.path.isfile(sample_idx_filename)) or \
-           (not os.path.isfile(shuffle_idx_filename)):
-
+                (not os.path.isfile(sample_idx_filename)) or \
+                (not os.path.isfile(shuffle_idx_filename)):
             print_rank_0(' > WARNING: could not find index map files, building '
                          'the indices on rank 0 ...')
             # doc-idx.
18 changes: 1 addition & 17 deletions megatron/training.py
@@ -316,22 +316,6 @@ def backward_step(optimizer, model, loss):
     else:
         raise ValueError("Must be using deepspeed to run neox")
 
-    if not args.deepspeed:
-        # Update master gradients.
-        timers('backward-master-grad').start()
-        if args.fp16:
-            optimizer.update_master_grads()
-        timers('backward-master-grad').stop()
-
-        # Clipping gradients helps prevent the exploding gradient.
-        timers('backward-clip-grad').start()
-        if args.clip_grad > 0:
-            if not args.fp16:
-                mpu.clip_grad_norm(model.parameters(), args.clip_grad)
-            else:
-                optimizer.clip_master_grads(args.clip_grad)
-        timers('backward-clip-grad').stop()
-
 
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
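
The block removed here was the non-DeepSpeed tail of backward_step: it updated fp16 master gradients and clipped them, via mpu.clip_grad_norm for fp32 or optimizer.clip_master_grads for fp16. When DeepSpeed owns the optimizer, loss scaling and gradient clipping are driven by the engine config instead of hand-written code. A hedged example of such a config, expressed as a Python dict with placeholder values rather than this repo's actual settings:

# Illustrative DeepSpeed config fragment (placeholder values, not the repo's):
# gradient clipping and fp16 loss scaling are handled inside the engine, which
# is why backward_step no longer needs the removed branch.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_clipping": 1.0,  # stands in for clip_grad_norm / clip_master_grads
    "fp16": {
        "enabled": True,
        "loss_scale": 0,  # 0 selects dynamic loss scaling
        "initial_scale_power": 16,
    },
}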
@@ -541,7 +525,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Logging.
         loss_scale = None
         if args.fp16:
-            loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale
+            loss_scale = optimizer.cur_scale
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
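One consequence of the DeepSpeed-only path shows up in the logging hunk above: with dynamic loss scaling handled by DeepSpeed's fp16 optimizer, the current scale is read from its cur_scale attribute, and the optimizer.loss_scale fallback for the old Megatron fp16 optimizer goes away. A defensive one-liner for logging, assuming the args and optimizer from the surrounding training loop, could be:

# Sketch: read the dynamic loss scale for logging; `args` and `optimizer`
# are assumed to come from the surrounding training loop.
loss_scale = getattr(optimizer, 'cur_scale', None) if args.fp16 else None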
1 change: 1 addition & 0 deletions tools/preprocess_data.py
@@ -184,6 +184,7 @@ def main():
                                               impl=args.dataset_impl,
                                               vocab_size=tokenizer.vocab_size)
 
+
     startup_end = time.time()
     proc_start = time.time()
     total_bytes_processed = 0
