Commit

add todos to calls of get_args()
Samuel Weinbach committed Apr 30, 2021
1 parent 7d738b6 commit 834d6fd
Showing 18 changed files with 67 additions and 69 deletions.
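
The TODO markers added below all tag call sites of get_args(), the accessor that reads the parsed arguments back out of module-level global state; the megatron/initialize.py hunk also removes the set_global_variables() call from initialize_megatron(). As a minimal sketch of the pattern being flagged (the _GLOBAL_ARGS name and the simplified bodies are illustrative stand-ins, not the repository's actual global_vars module):

# Illustrative sketch of the global-args pattern the TODOs flag; a simplified
# stand-in, not the real Megatron implementation.

_GLOBAL_ARGS = None  # module-level singleton holding the parsed arguments


def set_global_variables(args):
    """Stash the parsed arguments in module-level state (done once at startup)."""
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args


def get_args():
    """Return the stashed arguments; every call site marked with
    `# TODO remove_global_vars` depends on this hidden global rather than an
    explicit parameter."""
    assert _GLOBAL_ARGS is not None, "arguments have not been initialized"
    return _GLOBAL_ARGS

Removing the global means each of these call sites eventually has to receive the configuration as an explicit argument, which is the direction the new neox_args parameter on pretrain() in megatron/training.py already takes.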
6 changes: 3 additions & 3 deletions megatron/checkpointing.py
@@ -37,7 +37,7 @@
def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retrieved from checkpoint."""
args = get_args()
args = get_args() # TODO remove_global_vars

def _compare(arg_name):
checkpoint_value = getattr(checkpoint_args, arg_name)
@@ -117,7 +117,7 @@ def save_ds_checkpoint(iteration, model, args):

def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint."""
args = get_args()
args = get_args() # TODO remove_global_vars

if args.deepspeed:
save_ds_checkpoint(iteration, model, args)
@@ -143,7 +143,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

def load_checkpoint(model, optimizer, lr_scheduler):
"""Load a model checkpoint and return the iteration."""
args = get_args()
args = get_args() # TODO remove_global_vars

# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
12 changes: 4 additions & 8 deletions megatron/initialize.py
@@ -52,15 +52,11 @@ def initialize_megatron(allow_no_cuda=False):
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'

# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables()

args = get_args()
args = get_args() # TODO remove_global_vars

# torch.distributed initialization
def finish_mpu_init():
args = get_args()
args = get_args() # TODO remove_global_vars
# Pytorch distributed.
_initialize_distributed()

@@ -127,7 +123,7 @@ def setup_deepspeed_random_and_activation_checkpointing(args):

def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args = get_args()
args = get_args() # TODO remove_global_vars

device_count = torch.cuda.device_count()
if torch.distributed.is_initialized():
@@ -211,7 +207,7 @@ def _set_random_seed(seed):

def _write_args_to_tensorboard():
"""Write arguments to tensorboard."""
args = get_args()
args = get_args() # TODO remove_global_vars
writer = get_tensorboard_writer()
if writer:
for arg in vars(args):
2 changes: 1 addition & 1 deletion megatron/model/gpt2_model.py
@@ -70,7 +70,7 @@ class GPT2ModelPipe(PipelineModule, torch.nn.Module):
"""

def __init__(self, num_tokentypes=0, parallel_output=True, topology=None, inference=False, get_key_value=True):
args = get_args()
args = get_args() # TODO remove_global_vars

self._inference = inference
self.get_key_value = get_key_value if inference else False
10 changes: 5 additions & 5 deletions megatron/model/transformer.py
@@ -66,7 +66,7 @@ class GEGLU(torch.nn.Module):

def __init__(self):
super(GEGLU, self).__init__()
args = get_args()
args = get_args() # TODO remove_global_vars
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
@@ -101,7 +101,7 @@ class ParallelMLP(torch.nn.Module):

def __init__(self, init_method, output_layer_init_method):
super(ParallelMLP, self).__init__()
args = get_args()
args = get_args() # TODO remove_global_vars

if args.geglu:
self.activation_type = "geglu"
@@ -163,7 +163,7 @@ class ParallelLinear(torch.nn.Module):

def __init__(self, parallel_output=True, init_method=torch.nn.init.xavier_normal_):
super(ParallelLinear, self).__init__()
args = get_args()
args = get_args() # TODO remove_global_vars
self.final_linear = mpu.RowParallelLinear(
args.hidden_size,
args.padded_vocab_size,
@@ -188,7 +188,7 @@ def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=False,
rpe=None, rotary=False, get_key_value=False):
super(ParallelSelfAttention, self).__init__()
args = get_args()
args = get_args() # TODO remove_global_vars
self.fp16 = args.precision == "fp16"
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -465,7 +465,7 @@ class ParallelTransformerLayer(torch.nn.Module):
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=False, rpe=None, rotary=False, get_key_value=False):

args = get_args()
args = get_args() # TODO remove_global_vars

super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
2 changes: 1 addition & 1 deletion megatron/model/word_embeddings.py
@@ -26,7 +26,7 @@ def __init__(self,
init_method,
num_tokentypes=0):
super(Embedding, self).__init__()
args = get_args()
args = get_args() # TODO remove_global_vars
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
18 changes: 9 additions & 9 deletions megatron/text_generation_utils.py
@@ -34,7 +34,7 @@

def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()

# Move to GPU.
@@ -115,7 +115,7 @@ def get_token_stream(model, context_tokens):
model: a Megatron model.
context_tokens: the prompt to complete.
"""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()

context_tokens, context_lengths = pad_batch(context_tokens,
@@ -163,7 +163,7 @@ def forward_model(model, model_inputs):
"""
# because someone at deepspeed decided pipeline modules couldn't use kwargs,
# we need to forward a pipe model by accessing model.module() instead of just model()
args = get_args()
args = get_args() # TODO remove_global_vars
torch.distributed.barrier()
if args.pipe_parallel_size <= 1:
return model.module(model_inputs)
@@ -197,7 +197,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
yields: tokens (completions from model), and lengths (lengths of completions)
"""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()

model.eval()
@@ -291,7 +291,7 @@ def generate_samples_from_prompt(model, text: Union[List[str], str]):
- 'text' (the completion)
- 'length' (the length of the completion)
"""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()

# type check
@@ -367,7 +367,7 @@ def generate_samples_input_from_file(model):
model: a Megatron model
"""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
@@ -395,7 +395,7 @@ def generate_samples_interactive(model, print_frequency=24):
model: a Megatron model
print_frequency: int, how often (in tokens) to print the output.
"""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()

context_count = 0
@@ -470,7 +470,7 @@ def generate_samples_unconditional(model):
- 'text' (the completion)
- 'length' (the length of the completion)
"""
args = get_args()
args = get_args() # TODO remove_global_vars
tokenizer = get_tokenizer()

num_samples = args.num_samples
@@ -517,7 +517,7 @@ def generate_and_write_samples_unconditional(model):
model: a Megatron model
"""
args = get_args()
args = get_args() # TODO remove_global_vars
assert args.genfile is not None
genfile = args.genfile

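
The comment in forward_model() above notes that DeepSpeed pipeline modules cannot take kwargs, so a pipe-parallel model has to be driven differently from a plain module call. A small sketch of that dispatch pattern follows; the eval_batch branch is an assumption about the elided pipe-parallel path (only the non-pipe branch is visible in the hunk above), not the file's actual code.

import torch


def forward_model_sketch(model, model_inputs, pipe_parallel_size):
    """Illustrative dispatch only: with no pipeline parallelism, call the
    wrapped module directly; otherwise hand an iterator of batches to the
    pipeline engine (assumed eval_batch-style interface)."""
    if torch.distributed.is_initialized():
        torch.distributed.barrier()  # keep ranks in step before the forward pass
    if pipe_parallel_size <= 1:
        return model.module(model_inputs)
    return model.eval_batch(iter([model_inputs]))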
33 changes: 17 additions & 16 deletions megatron/training.py
@@ -46,8 +46,7 @@
import deepspeed


def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={}):
def pretrain(train_valid_test_dataset_provider, model_provider, forward_step_func, neox_args):
"""Main training program.
This function will run the following in the order provided:
@@ -66,13 +65,15 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
neox_args: an instance of NeoXArgs containing the configuration for pretrain
"""

# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron()

args = get_args()

args = get_args() # TODO remove_global_vars
timers = get_timers()

# Model, optimizer, and learning rate.
@@ -117,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,

def get_model(model_provider_func):
"""Build the model."""
args = get_args()
args = get_args() # TODO remove_global_vars

# Build model on cpu.
model = model_provider_func()
@@ -131,7 +132,7 @@ def get_model(model_provider_func):

def get_optimizer(model):
"""Set up the optimizer."""
args = get_args()
args = get_args() # TODO remove_global_vars
if args.no_load_optim:
return None, None
# Build parameter groups (weight decay and non-decay).
@@ -185,7 +186,7 @@ def get_optimizer(model):

def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
args = get_args() # TODO remove_global_vars
if args.no_load_optim:
# TODO: this should be configured as a separate arg
return None
@@ -218,7 +219,7 @@ def get_learning_rate_scheduler(optimizer):

def setup_model_and_optimizer(model_provider_func):
"""Setup model and optimizer."""
args = get_args()
args = get_args() # TODO remove_global_vars

model = get_model(model_provider_func)
optimizer, param_groups = get_optimizer(model)
@@ -263,7 +264,7 @@ def setup_model_and_optimizer(model_provider_func):

def backward_step(optimizer, model, loss):
"""Backward step."""
args = get_args()
args = get_args() # TODO remove_global_vars
timers = get_timers()

# Backward pass.
@@ -285,7 +286,7 @@ def backward_step(optimizer, model, loss):
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
args = get_args() # TODO remove_global_vars
timers = get_timers()

# Pipeline parallelism schedules forward/backward/step
@@ -318,7 +319,7 @@ def train_step(forward_step_func, data_iterator,

def train_step_pipe(model, data_iterator):
"""Single training step with DeepSpeed's pipeline parallel engine. """
args = get_args()
args = get_args() # TODO remove_global_vars
timers = get_timers()

assert args.deepspeed
@@ -339,7 +340,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter, model, optimizer, noise_scale_logger):
"""Log training information such as losses, timing, etc."""

args = get_args()
args = get_args() # TODO remove_global_vars
timers = get_timers()

# Update losses.
@@ -474,7 +475,7 @@ def add_to_logging(name):
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator):
"""Train the model function."""
args = get_args()
args = get_args() # TODO remove_global_vars
timers = get_timers()

# Turn on training mode which enables dropout.
@@ -556,7 +557,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,

def evaluate(forward_step_func, data_iterator, model, verbose=False):
"""Evaluation."""
args = get_args()
args = get_args() # TODO remove_global_vars

# Turn on evaluation mode which disables dropout.
model.eval()
@@ -599,7 +600,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
"""Helper function to evaluate and dump results on screen."""

# Pipeline parallelism needs eval_batch() instead of a simple forward().
args = get_args()
args = get_args() # TODO remove_global_vars
if args.is_pipe_parallel:
def _eval_helper(data_iter, _):
loss = model.eval_batch(data_iter)
@@ -624,7 +625,7 @@ def _eval_helper(data_iter, _):
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
"""XXX"""
args = get_args()
args = get_args() # TODO remove_global_vars

(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

@@ -749,7 +750,7 @@ def get_global_batch_size(args):


def get_flops(model, iter_time_s):
args = get_args()
args = get_args() # TODO remove_global_vars

world_size = torch.distributed.get_world_size()
global_batch_size = get_global_batch_size(args)
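
The megatron/training.py hunks above also carry the one signature change in this commit: pretrain() now receives a neox_args object in place of extra_args_provider/args_defaults, with the docstring describing it as an instance of NeoXArgs. Below is a hedged sketch of the calling convention this enables; TrainingConfig and the stub providers are hypothetical stand-ins rather than the real NeoXArgs class or the repository's dataset and model providers, and forward_step_func is omitted for brevity.

from dataclasses import dataclass


@dataclass
class TrainingConfig:
    """Hypothetical stand-in for NeoXArgs; fields are illustrative only."""
    train_iters: int = 1000
    seed: int = 1234


def dataset_provider(config: TrainingConfig):
    """Stub provider: would build train/valid/test datasets from the config."""
    return None, None, None


def model_provider(config: TrainingConfig):
    """Stub provider: would construct the model from the config."""
    return object()


def pretrain(train_valid_test_dataset_provider, model_provider, neox_args):
    """Simplified mirror of the new convention: the configuration arrives as an
    explicit parameter, so nothing in the body needs to call get_args()."""
    model = model_provider(neox_args)
    train_ds, valid_ds, test_ds = train_valid_test_dataset_provider(neox_args)
    return model, (train_ds, valid_ds, test_ds)


if __name__ == "__main__":
    pretrain(dataset_provider, model_provider, TrainingConfig(train_iters=10))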
6 changes: 3 additions & 3 deletions megatron/utils.py
@@ -65,7 +65,7 @@ def check_adlr_autoresume_termination(iteration, model,
"""Check for autoresume signal and exit if it is received."""
# to prevent circular import
from megatron.checkpointing import save_checkpoint
args = get_args()
args = get_args() # TODO remove_global_vars
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistency.
torch.distributed.barrier()
@@ -83,7 +83,7 @@ def make_data_loader(dataset):
"""Buld dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
args = get_args() # TODO remove_global_vars

# Data parallel arguments.
world_size = mpu.get_data_parallel_world_size()
@@ -226,7 +226,7 @@ def pipe_to_normal(model_engine, **kwargs):
"""
assert isinstance(model_engine, PipelineEngine), f"model engine {model_engine} not a PipelineEngine instance"
ret = DeepSpeedEngine(
args=get_args(),
args=get_args(), # TODO remove_global_vars
model=model_engine.module,
mpu=model_engine.module.mpu(),
dist_init_required=False,