Commit

remove unused args and codepaths
sid committed Mar 17, 2021
1 parent 139efb1 commit 9adc752
Showing 3 changed files with 12 additions and 146 deletions.
59 changes: 4 additions & 55 deletions megatron/arguments.py
@@ -46,7 +46,6 @@ def _get_parser(extra_args_provider=None):
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
     parser = _add_zero_args(parser)
     parser = _add_activation_checkpoint_args(parser)

@@ -139,9 +138,6 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.num_unique_layers <= args.num_layers
     assert args.num_layers % args.num_unique_layers == 0, \
         'num-layers should be divisible by num-unique-layers.'
-    if args.num_unique_layers < args.num_layers:
-        assert args.DDP_impl == 'local', \
-            'torch-DDP does not work with parameters sharing.'
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
@@ -196,7 +192,7 @@ def _add_network_size_args(parser):
                        'layers 1 and 2: '
                        ' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
-                       help='Tansformer hidden size.')
+                       help='Transformer hidden size.')
     group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
     group.add_argument('--max-position-embeddings', type=int, default=None,
@@ -438,10 +434,6 @@ def _add_distributed_args(parser):
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo', 'mpi'],
                        help='Which backend to use for distributed training.')
-    group.add_argument('--DDP-impl', default='local',
-                       choices=['local', 'torch'],
-                       help='which DistributedDataParallel implementation '
-                       'to use.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
@@ -456,14 +448,12 @@ def _add_distributed_args(parser):

 def _add_validation_args(parser):
     group = parser.add_argument_group(title='validation')
-
     group.add_argument('--eval-iters', type=int, default=100,
                        help='Number of iterations to run for evaluation'
                        'validation/test for.')
     group.add_argument('--eval-interval', type=int, default=1000,
                        help='Interval between running evaluation on '
                        'validation set.')
-
     return parser


@@ -483,8 +473,6 @@ def _add_data_args(parser):
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
                        help="Maximum sequence length to process.")
-    group.add_argument('--mask-prob', type=float, default=0.15,
-                       help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
                        help='Probability of producing a short sequence.')
     group.add_argument('--mmap-warmup', action='store_true',
@@ -499,13 +487,13 @@ def _add_data_args(parser):
                        choices=['lazy', 'cached', 'mmap', 'infer'],
                        help='Implementation of indexed datasets.')
     group.add_argument('--reset-position-ids', action='store_true',
-                       help='Reset posistion ids after end-of-document token.')
+                       help='Reset position ids after end-of-document token.')
     group.add_argument('--reset-attention-mask', action='store_true',
-                       help='Reset self attention maske after '
+                       help='Reset self attention mask after '
                        'end-of-document token.')
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens.')
-    group.add_argument('--log-dir', type=str, help='Directory to store logs.')
+    group.add_argument('--log-dir', type=str, help='Directory to store logs.', default='./logs')

     return parser

@@ -522,45 +510,6 @@ def _add_autoresume_args(parser):
     return parser


-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
-
-    # network size
-    group.add_argument('--ict-head-size', type=int, default=None,
-                       help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
-
-    # checkpointing
-    group.add_argument('--ict-load', type=str, default=None,
-                       help='Directory containing an ICTBertModel checkpoint')
-    group.add_argument('--bert-load', type=str, default=None,
-                       help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')
-
-    # data
-    group.add_argument('--titles-data-path', type=str, default=None,
-                       help='Path to titles dataset used for ICT')
-    group.add_argument('--query-in-block-prob', type=float, default=0.1,
-                       help='Probability of keeping query in block for ICT dataset')
-    group.add_argument('--use-one-sent-docs', action='store_true',
-                       help='Whether to use one sentence documents in ICT')
-
-    # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
-                       help="Which top-k accuracies to report (e.g. '1 5 20')")
-
-    # faiss index
-    group.add_argument('--faiss-use-gpu', action='store_true',
-                       help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Where to save/load BlockData to/from')
-
-    # indexer
-    group.add_argument('--indexer-batch-size', type=int, default=128,
-                       help='How large of batches to use when doing indexing jobs')
-    group.add_argument('--indexer-log-interval', type=int, default=1000,
-                       help='After how many batches should the indexer report progress')
-    return parser
-
-
 def _add_zero_args(parser):
     """Text generate arguments."""

40 changes: 0 additions & 40 deletions megatron/checkpointing.py
@@ -331,43 +331,3 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print(' successfully loaded {}'.format(checkpoint_name))

     return iteration
-
-
-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
-
-    args = get_args()
-
-    if isinstance(model, torchDDP):
-        model = model.module
-
-    load_path = args.load if from_realm_chkpt else args.ict_load
-
-    tracker_filename = get_checkpoint_tracker_filename(load_path)
-    with open(tracker_filename, 'r') as f:
-        iteration = int(f.read().strip())
-
-    # assert iteration > 0
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
-
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
-
-    if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
-
-    model.load_state_dict(ict_state_dict)
-    torch.distributed.barrier()
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
-
-    return model
59 changes: 8 additions & 51 deletions megatron/training.py
@@ -137,26 +137,8 @@ def get_model(model_provider_func):
     if args.deepspeed:
         # DeepSpeed handles CUDA, FP16, and DDP components.
         return model
-
-    # GPU allocation.
-    model.cuda(torch.cuda.current_device())
-
-    # Fp16 conversion.
-    if args.fp16:
-        model = FP16_Module(model)
-
-    # Wrap model for distributed training."""
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        model = torchDDP(model, device_ids=[i], output_device=i,
-                         process_group=mpu.get_data_parallel_group())
-        return model
-    if args.DDP_impl == 'local':
-        model = LocalDDP(model)
-        return model
-
-    raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                              'Exiting.'.format(args.DDP_impl))
+    else:
+        raise ValueError("Must be using deepspeed to run neox")


 def get_optimizer(model):
@@ -206,18 +188,8 @@ def get_optimizer(model):
     if args.deepspeed:
         # fp16 wrapper is not required for DeepSpeed.
         return optimizer, param_groups
-
-    # Wrap into fp16 optimizer.
-    if args.fp16:
-        optimizer = FP16_Optimizer(optimizer,
-                                   static_loss_scale=args.loss_scale,
-                                   dynamic_loss_scale=args.dynamic_loss_scale,
-                                   dynamic_loss_args={
-                                       'scale_window': args.loss_scale_window,
-                                       'min_scale': args.min_scale,
-                                       'delayed_shift': args.hysteresis})
-
-    return optimizer, param_groups
+    else:
+        raise ValueError("Must be using deepspeed to run neox")


 def get_learning_rate_scheduler(optimizer):
@@ -296,8 +268,7 @@ def setup_model_and_optimizer(model_provider_func):
         if args.pipe_parallel_size > 0:
             model.set_batch_fn(model.module._megatron_batch_fn)
     else:
-        model.total_params = get_total_params(model)
-        print(f' > total params: {model.total_params}')
+        raise ValueError("Must be using deepspeed to run neox")

     if args.load is not None:
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
@@ -322,24 +293,15 @@ def backward_step(optimizer, model, loss):
     if args.deepspeed:
         model.backward(loss)
     else:
-        optimizer.zero_grad(set_grads_to_None=True)
-        if args.fp16:
-            optimizer.backward(loss, update_master_grads=False)
-        else:
-            loss.backward()
+        raise ValueError("Must be using deepspeed to run neox")
     timers('backward-backward').stop()

     if args.deepspeed:
         # DeepSpeed backward propagation already addressed all reduce communication.
         # Reset the timer to avoid breaking timer logs below.
         timers('backward-allreduce').reset()
     else:
-        # All-reduce if needed.
-        if args.DDP_impl == 'local':
-            timers('backward-allreduce').start()
-            model.allreduce_params(reduce_after=False,
-                                   fp32_allreduce=args.fp32_allreduce)
-            timers('backward-allreduce').stop()
+        raise ValueError("Must be using deepspeed to run neox")

     if not args.deepspeed:
         # Update master gradients.
@@ -384,12 +346,7 @@ def train_step(forward_step_func, data_iterator,
     if args.deepspeed:
         model.step()
     else:
-        optimizer.step()
-        # Update learning rate.
-        if not (args.fp16 and optimizer.overflow):
-            lr_scheduler.step()
-        else:
-            skipped_iter = 1
+        raise ValueError("Must be using deepspeed to run neox")
     timers('optimizer').stop()

     return loss_reduced, skipped_iter
