diff --git a/configs/coord_check_mup.yml b/configs/coord_check_mup.yml new file mode 100644 index 000000000..a09090029 --- /dev/null +++ b/configs/coord_check_mup.yml @@ -0,0 +1,103 @@ +{ + # parallelism settings + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 2, + "num_attention_heads": 4, + "seq_length": 2048, + "max_position_embeddings": 2048, + "pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + "output_layer_parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": true, + + # # init methods + # "init_method": "small_init", + # "output_layer_init_method": "wang_init", + + # init methods + "init_method": "normal", + "output_layer_init_method": "scaled_normal", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "lr_decay_style": constant, + "warmup": 0, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 1260000000, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 1260000000, + "contiguous_gradients": true, + "cpu_offload": false + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 8, + "gradient_accumulation_steps": 1, + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "precision": "fp32", + # "fp16": { + # "fp16": true, + # "enabled": true, + # "loss_scale": 0, + # "loss_scale_window": 1000, + # "hysteresis": 2, + # "min_loss_scale": 1 + # }, + + # misc. 
training settings + "train_iters": 10, + "log_interval": 1, + "distributed_backend": "nccl", + + "coord_check": true, + "coord_check_nsteps": 5, + "coord_check_nseeds": 1, + "use_mup": true, + # base lr + "mup_lr": 0.01, + # base sigma + "mup_std": 0.08, + # base size + "mup_d_model_base": 256, + "mup_hidden_size": 256, + + "tokenizer_type": "HFTokenizer", + "vocab-file": "/mnt/ssd-1/lintang/09-mup-neox/20B_tokenizer.json", + "data-path": "/mnt/ssd-1/lintang/09-mup-neox/data/enwik8/enwik8_text_document", + "mup_save": "/mnt/ssd-1/lintang/09-mup-neox/mup_results", + +} diff --git a/configs/coord_check_sp.yml b/configs/coord_check_sp.yml new file mode 100644 index 000000000..66573892d --- /dev/null +++ b/configs/coord_check_sp.yml @@ -0,0 +1,101 @@ +{ + # parallelism settings + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 2, + "num_attention_heads": 4, + "seq_length": 2048, + "max_position_embeddings": 2048, + "pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + "output_layer_parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": true, + + # # init methods + # "init_method": "small_init", + # "output_layer_init_method": "wang_init", + + # init methods + "init_method": "normal", + "output_layer_init_method": "scaled_normal", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.01, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "lr_decay_style": constant, + "warmup": 0, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 1260000000, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 1260000000, + "contiguous_gradients": true, + "cpu_offload": false + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 8, + "gradient_accumulation_steps": 1, + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "precision": "fp32", + # "fp16": { + # "fp16": true, + # "enabled": true, + # "loss_scale": 0, + # "loss_scale_window": 1000, + # "hysteresis": 2, + # "min_loss_scale": 1 + # }, + + # misc. 
training settings + "train_iters": 10, + "log_interval": 1, + "distributed_backend": "nccl", + + "coord_check": true, + "coord_check_nsteps": 5, + "coord_check_nseeds": 1, + # "use_mup": true, + # base sigma + "init_method_std": 0.08, + # base size + "hidden_size": 256, + + "tokenizer_type": "HFTokenizer", + "vocab-file": "/mnt/ssd-1/lintang/09-mup-neox/20B_tokenizer.json", + "data-path": "/mnt/ssd-1/lintang/09-mup-neox/data/enwik8/enwik8_text_document", + "mup_save": "/mnt/ssd-1/lintang/09-mup-neox/mup_results", + +} diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 8bc184720..0ebb1b063 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 4e37645 + Default = 6a8ad71 current git hash of repository @@ -274,6 +274,7 @@ Model Arguments Default = None Transformer hidden size. + When using muP, this is d_model @@ -558,6 +559,7 @@ Model Arguments Default = 0.02 Standard deviation of the zero mean normal distribution used for weight initialization. + When using muP this is the base std @@ -830,6 +832,7 @@ Optimizer Arguments Default = None Max Learning rate during training + When using muP, this is the base learning rate @@ -1778,7 +1781,42 @@ Training Arguments Default = False - Whether to use Microsoft's Mup https://github.com/microsoft/mup + Whether to use muP + + + +- **mup_save**: str + + Default = None + + Path to save results when using muP + + + +- **mup_lr**: float + + Default = None + + An alias parameter for lr, + if not None will override lr + + + +- **mup_std**: float + + Default = None + + An alias parameter for init_method_std, + if not None will override init_method_std + + + +- **mup_hidden_size**: int + + Default = None + + An alias parameter for hidden_size, + if not None will override hidden_size @@ -1790,68 +1828,68 @@ Training Arguments -- **save_base_shapes**: bool +- **coord_check_nsteps**: int - Default = False + Default = 10 - Whether to save base shapes for mup. This will save the shapes to the path specified in base-shapes-file. + Number of steps to do for the coordinate check -- **base_shapes_file**: str +- **coord_check_nseeds**: int - Default = None + Default = 5 - Path to the base shapes to save to/load from + Number of repetition for each size in coordinate check -- **mup_init_scale**: float +- **save_base_shapes**: bool - Default = 1.0 + Default = False - Initialization scale: All the parameters are multiplied by this value + Whether to save base shapes for mup. This will save the shapes to the path specified in base-shapes-file. -- **mup_attn_temp**: float +- **base_shapes_file**: str - Default = 1.0 + Default = None - Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax + Path to the base shapes to save to/load from -- **mup_output_temp**: float +- **mup_embedding_multiplier**: float Default = 1.0 - Output temperature: Reciprocal of the multiplier applied to the input to softmax that - produces the distribution over output tokens. 
+ Embedding output multiplier -- **mup_embedding_mult**: float +- **mup_output_multiplier**: float Default = 1.0 - Scalar by which we multiply the output of the embedding layer + Output logits multiplier -- **mup_rp_embedding_mult**: float +- **mup_width_multiplier**: float - Default = 1.0 + Default = None - Scalar by which we multiply vectors representing relative position + Manually set the layer width multiplier (d_model/d_model,base) -- **mup_width_scale**: int +- **mup_d_model_base**: int - Default = 2 + Default = 256 - What to scale width by when creating the delta model for mup + d_model,base + Proxy (base) model's layer width diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 9db951aa0..4ae18d49b 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -37,6 +37,7 @@ def __init__( use_checkpoint_lr_scheduler=True, override_lr_scheduler=False, use_mup=False, + mup_width_multiplier=1, ): # Class values. @@ -51,6 +52,7 @@ def __init__( self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler self.use_mup = use_mup + self.mup_width_multiplier = mup_width_multiplier if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, ( "both override and " "use-checkpoint are set." @@ -95,8 +97,8 @@ def step(self, step_num=None): self.num_iters = step_num new_lr = self.get_lr() for group in self.optimizer.param_groups: - if self.use_mup and "width_mult" in group: - group["lr"] = new_lr / group["width_mult"] + if self.use_mup and ("lr_adjust" in group) and group["lr_adjust"] is True: + group["lr"] = new_lr / self.mup_width_multiplier else: group["lr"] = new_lr diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9c86d98d3..c3378810c 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -114,14 +114,15 @@ def __init__( use_cache=False, ): self.neox_args = neox_args - self.use_cache = use_cache self.parallel_output = parallel_output self.hidden_size = self.neox_args.hidden_size self.num_tokentypes = num_tokentypes - self.init_method, self.output_layer_init_method = get_init_methods( - self.neox_args - ) + ( + self.init_method, + self.input_embedding_init_method, + self.output_layer_init_method, + ) = get_init_methods(self.neox_args) self.__topology__ = topology self.specs = [] @@ -186,6 +187,7 @@ def init_specs(self): # Embedding layer # input will be (input_ids, position_ids, attention_mask) + # TODO Initialized weights here should not be divided by m_width if weight_tying: self.specs.append( TiedLayerSpec( @@ -196,7 +198,7 @@ def init_specs(self): self.neox_args.padded_vocab_size, self.neox_args.max_position_embeddings, self.neox_args.hidden_dropout, - self.init_method, + self.input_embedding_init_method, self.num_tokentypes, tied_weight_attr="word_embeddings_weight", ) @@ -210,7 +212,7 @@ def init_specs(self): self.neox_args.padded_vocab_size, self.neox_args.max_position_embeddings, self.neox_args.hidden_dropout, - self.init_method, + self.input_embedding_init_method, self.num_tokentypes, ) ) @@ -289,16 +291,12 @@ def init_specs(self): def _logits_helper(embedding, lm_output): """Just a wrapper to massage inputs/outputs from pipeline.""" - if self.neox_args.use_mup: - # Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout. 
- # https://github.com/microsoft/mup/issues/6#issuecomment-1082156274 - lm_output = ( - lm_output - / self.tied_modules.embed.word_embeddings.weight.infshape.width_mult() - ) logits = parallel_lm_logits( - lm_output, embedding.word_embeddings_weight, self.parallel_output + lm_output, + embedding.word_embeddings_weight, + self.parallel_output, + self.neox_args, ) return logits diff --git a/megatron/model/init_functions.py b/megatron/model/init_functions.py index 86a003dbd..152acc3b6 100644 --- a/megatron/model/init_functions.py +++ b/megatron/model/init_functions.py @@ -16,23 +16,12 @@ import torch -try: - import mup -except ImportError: - pass +def init_method_normal(sigma): + """Init method based on N(0, sigma^2).""" -def init_method_normal(sigma, use_mup_outer=False, mup_init_scale=1.0): - """Init method based on N(0, sigma).""" - - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.normal_(tensor, mean=0.0, std=sigma) - with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) return init_ @@ -40,8 +29,6 @@ def init_(tensor, use_mup=use_mup_outer): def scaled_init_method_normal( sigma, num_layers, - use_mup_outer=False, - mup_init_scale=1.0, num_residuals_per_layer=2, ): """Init method based on N(0, sigma/sqrt(2*num_layers). @@ -51,14 +38,8 @@ def scaled_init_method_normal( """ std = sigma / math.sqrt(num_residuals_per_layer * num_layers) - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.normal_(tensor, mean=0.0, std=std) - with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.normal_(tensor, mean=0.0, std=std) + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) return init_ @@ -95,12 +76,12 @@ def _orthogonal(tensor, gain=1): return tensor -def orthogonal_init_method(n_layers=1, use_mup=False, mup_init_scale=1.0): +def orthogonal_init_method(n_layers=1, mup_width_multiplier=1.0): """Fills the input Tensor with a (semi) orthogonal matrix, as described in Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013) Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released)""" - if use_mup: + if mup_width_multiplier != 1: raise ValueError( "Orthogonal init needs to be patched to support mup. Disable mup or use a different init method to avoid this error" ) @@ -111,105 +92,93 @@ def init_(tensor): return init_ -def xavier_uniform_init_method(use_mup_outer=False, mup_init_scale=1.0): +def xavier_uniform_init_method(mup_width_multiplier=1.0): """Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. 
(2010), using a uniform distribution.""" - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.xavier_uniform_(tensor) + def init_(tensor, mup_width_multiplier=mup_width_multiplier): + init_weight = torch.nn.init.xavier_uniform_(tensor) + if mup_width_multiplier != 1: with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.xavier_uniform_(tensor) + init_weight.div_(math.sqrt(mup_width_multiplier)) + return init_weight return init_ -def xavier_normal_init_method(use_mup_outer=False, mup_init_scale=1.0): +def xavier_normal_init_method(mup_width_multiplier=1.0): """Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution.""" - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.xavier_normal_(tensor) + def init_(tensor, mup_width_multiplier=mup_width_multiplier): + init_weight = torch.nn.init.xavier_normal_(tensor) + if mup_width_multiplier != 1: with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.xavier_normal_(tensor) + init_weight.div_(math.sqrt(mup_width_multiplier)) + return init_weight return init_ -def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0): +def small_init_init_method(dim, mup_width_multiplier=1.0): """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution.""" - std = math.sqrt(2 / (5 * dim)) + std = math.sqrt(2 / (5 * dim)) / math.sqrt(args.mup_width_multiplier) - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.normal_(tensor, mean=0.0, std=std) - with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.normal_(tensor, mean=0.0, std=std) + def init_(tensor, mup_width_multiplier=mup_width_multiplier): + init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std) + return init_weight return init_ -def wang_init_method(n_layers, dim, use_mup_outer=False, mup_init_scale=1.0): - std = 2 / n_layers / math.sqrt(dim) +def wang_init_method(n_layers, dim, mup_width_multiplier=1.0): + std = 2 / n_layers / math.sqrt(dim) / math.sqrt(args.mup_width_multiplier) - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.normal_(tensor, mean=0.0, std=std) - with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.normal_(tensor, mean=0.0, std=std) + def init_(tensor, mup_width_multiplier=mup_width_multiplier): + init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std) + return init_weight return init_ def get_init_methods(args): - - if args.use_mup: - try: - import mup - except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") - raise Exception - - def _get(name): + def _get(name, use_mup=False): if name == "normal": + sigma = args.init_method_std + if use_mup: + sigma = sigma / math.sqrt(args.mup_width_multiplier) return init_method_normal( - args.init_method_std, args.use_mup, args.mup_init_scale + sigma=sigma, ) elif name == "scaled_normal": - return scaled_init_method_normal( - args.init_method_std, args.num_layers, args.use_mup, args.mup_init_scale - ) + sigma = args.init_method_std + if use_mup: + sigma = sigma / math.sqrt(args.mup_width_multiplier) + return 
scaled_init_method_normal(sigma=sigma, num_layers=args.num_layers) elif name == "orthogonal": - return orthogonal_init_method(args.use_mup, args.mup_init_scale) + return orthogonal_init_method(args.mup_width_multiplier if use_mup else 1.0) elif name == "scaled_orthogonal": return orthogonal_init_method( - args.num_layers, args.use_mup, args.mup_init_scale + args.num_layers, args.mup_width_multiplier if use_mup else 1.0 ) elif name == "xavier_uniform": - return xavier_uniform_init_method(args.use_mup, args.mup_init_scale) + return xavier_uniform_init_method( + args.mup_width_multiplier if use_mup else 1.0 + ) elif name == "xavier_normal": - return xavier_normal_init_method(args.use_mup, args.mup_init_scale) + return xavier_normal_init_method( + args.mup_width_multiplier if use_mup else 1.0 + ) elif name == "wang_init": return wang_init_method( - args.num_layers, args.hidden_size, args.use_mup, args.mup_init_scale + args.num_layers, + args.hidden_size, + args.mup_width_multiplier if use_mup else 1.0, ) elif name == "small_init": return small_init_init_method( - args.hidden_size, args.use_mup, args.mup_init_scale + args.hidden_size, args.mup_width_multiplier if use_mup else 1.0 ) elif name == "single_residual_scaled_normal": # mamba init uses scaled_normal but no need for 2 * num_layers @@ -224,4 +193,8 @@ def _get(name): else: raise NotImplementedError(f"Unknown init method {name}") - return _get(args.init_method), _get(args.output_layer_init_method) + return ( + _get(args.init_method, use_mup=args.use_mup), + _get(args.init_method), + _get(args.output_layer_init_method, use_mup=args.use_mup), + ) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c6bf619f8..6821d73ef 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -252,9 +252,11 @@ def __init__( init_method=init_method, gather_output=not parallel_output, skip_bias_add=False, - mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here ) + self.neox_args = neox_args + self.is_last_layer = is_last_layer + # else: # print( # 'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.' 
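The `get_init_methods` changes above boil down to a simple rule: the input-embedding initializer keeps the base standard deviation, while the hidden and output-layer initializers divide it by the square root of the width multiplier. A minimal, standalone sketch of that rule (the `width_multiplier` name and the toy shapes below are illustrative assumptions, not the NeoX API):

```python
import math
import torch

def mup_init_sketch(tensor, base_std, width_multiplier, is_embedding=False):
    # Embedding weights keep the base std; matrix-like hidden/output weights
    # are drawn with std = base_std / sqrt(width_multiplier).
    std = base_std if is_embedding else base_std / math.sqrt(width_multiplier)
    return torch.nn.init.normal_(tensor, mean=0.0, std=std)

# e.g. d_model = 1024 trained against d_model_base = 256 -> width multiplier of 4
width_multiplier = 1024 / 256
embed_w = mup_init_sketch(torch.empty(1000, 1024), base_std=0.08,
                          width_multiplier=width_multiplier, is_embedding=True)
hidden_w = mup_init_sketch(torch.empty(1024, 1024), base_std=0.08,
                           width_multiplier=width_multiplier)
```

With `use_mup` off, the multiplier is effectively 1 and both cases reduce to the standard parameterization.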
@@ -269,11 +271,18 @@ def __init__( # init_method=init_method, # parallel_output=parallel_output, # skip_bias_add=False, - # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here # ) def forward(self, hidden_states): - return self.final_linear(hidden_states) + logits = self.final_linear(hidden_states) + + if self.is_last_layer: + _logits, *_args = logits + if self.neox_args.use_mup: + _logits /= self.neox_args.mup_width_multiplier + _logits *= self.neox_args.mup_output_multiplier + logits = (_logits, *_args) + return logits class ParallelSelfAttention(nn.Module): @@ -370,14 +379,15 @@ def __init__( ) coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if neox_args.use_mup: + self.norm_factor = self.hidden_size_per_attention_head + else: + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: coeff = max(1, self.layer_number) self.norm_factor *= coeff - if neox_args.use_mup: - self.norm_factor = self.hidden_size_per_attention_head - self.rpe = rpe if self.pos_emb == "alibi": @@ -1212,7 +1222,9 @@ def forward(self, args): return self.norm(args) -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): +def parallel_lm_logits( + input_, word_embeddings_weight, parallel_output, bias=None, args=None +): """LM logits using word embedding weights.""" # Parallel logits. input_parallel = mpu.copy_to_model_parallel_region(input_) @@ -1223,6 +1235,10 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non else: logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + if args is not None and args.use_mup: + logits_parallel /= args.mup_width_multiplier + logits_parallel *= args.mup_output_multiplier + # Gather if needed. if parallel_output: return logits_parallel diff --git a/megatron/model/utils.py b/megatron/model/utils.py index c3da2ce8b..4e7d481c4 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -18,6 +18,7 @@ """Utilities for models.""" import torch +from megatron.mpu import VocabParallelEmbedding from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm from megatron.model.fused_softmax import SoftmaxFusionTypes from types import GeneratorType @@ -28,49 +29,105 @@ def get_params_for_weight_decay_optimization(module, neox_args): """Divide params into with-weight-decay and without-weight-decay groups. Layernorms and biases will have no weight decay but the rest will. 
""" - weight_decay_params = {"params": [], "name": "weight_decay_params"} - no_weight_decay_params = { + + lr_adjust_weight_decay_params = {"params": [], "lr_adjust": True, "name": "lr_adjust_weight_decay_params"} + lr_adjust_no_weight_decay_params = { "params": [], + "lr_adjust": True, "weight_decay": 0.0, - "name": "no_weight_decay_params", + "name": "lr_adjust_no_weight_decay_params", } + no_lr_adjust_weight_decay_params = {"params": [], "lr_adjust": False, "name": "no_lr_adjust_weight_decay_params"} + no_lr_adjust_no_weight_decay_params = { + "params": [], + "lr_adjust": False, + "weight_decay": 0.0, + "name": "no_lr_adjust_no_weight_decay_params", + } + for module_ in module.modules(): - if any( - [ - isinstance(module_, LayerNorm), - isinstance(module_, RMSNorm), - isinstance(module_, ScaleNorm), - ] - ) or ( - neox_args.weight_decay == 0.0 - ): # also include all parameters here if no weight decay is being done - no_weight_decay_params["params"].extend( - [p for p in list(module_._parameters.values()) if p is not None] - ) - else: - weight_decay_params["params"].extend( + if neox_args.weight_decay == 0.0: + if any( [ - p - for n, p in list(module_._parameters.items()) - if p is not None - and n != "bias" - and not getattr(p, "_no_weight_decay", False) + isinstance(module_, LayerNorm), + isinstance(module_, RMSNorm), + isinstance(module_, ScaleNorm), + isinstance(module_, VocabParallelEmbedding), ] - ) - no_weight_decay_params["params"].extend( + ): + no_lr_adjust_no_weight_decay_params["params"].extend( + [p for p in list(module_._parameters.values()) if p is not None] + ) + else: + no_lr_adjust_no_weight_decay_params["params"].extend( + [ + p + for n, p in list(module_._parameters.items()) + if p is not None and (n == "bias" or getattr(p, "_no_weight_decay", False)) + ] + ) + lr_adjust_no_weight_decay_params["params"].extend( + [ + p + for n, p in list(module_._parameters.items()) + if p is not None and (n != "bias" or getattr(p, "_no_weight_decay", False)) + ] + ) + else: + if any( [ - p - for n, p in list(module_._parameters.items()) - if p is not None - and (n == "bias" or getattr(p, "_no_weight_decay", False)) + isinstance(module_, LayerNorm), + isinstance(module_, RMSNorm), + isinstance(module_, ScaleNorm), ] - ) + ): + no_lr_adjust_no_weight_decay_params["params"].extend( + [p for p in list(module_._parameters.values()) if p is not None] + ) + + elif isinstance(module_, VocabParallelEmbedding): + no_lr_adjust_weight_decay_params["params"].extend( + [ + p + for n, p in list(module_._parameters.items()) + if p is not None and n != "bias" and not getattr(p, "_no_weight_decay", False) + ] + ) + no_lr_adjust_no_weight_decay_params["params"].extend( + [ + p + for n, p in list(module_._parameters.items()) + if p is not None and (n == "bias" or getattr(p, "_no_weight_decay", False)) + ] + ) + else: + lr_adjust_weight_decay_params["params"].extend( + [ + p + for n, p in list(module_._parameters.items()) + if p is not None and n != "bias" and not getattr(p, "_no_weight_decay", False) + ] + ) + lr_adjust_no_weight_decay_params["params"].extend( + [ + p + for n, p in list(module_._parameters.items()) + if p is not None and (n == "bias" or getattr(p, "_no_weight_decay", False)) + ] + ) + if neox_args.weight_decay == 0.0: # only return a single param group # with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once. # to avoid this, only use a single param group when weight decay is off. 
- return [no_weight_decay_params] - return weight_decay_params, no_weight_decay_params + # return (lr_adjust_no_weight_decay_params, no_lr_adjust_no_weight_decay_params) + return (lr_adjust_no_weight_decay_params, no_lr_adjust_no_weight_decay_params) + return ( + lr_adjust_weight_decay_params, + lr_adjust_no_weight_decay_params, + no_lr_adjust_weight_decay_params, + no_lr_adjust_no_weight_decay_params, + ) def exists(x): diff --git a/megatron/model/word_embeddings.py b/megatron/model/word_embeddings.py index f7372bc55..22ea5989d 100644 --- a/megatron/model/word_embeddings.py +++ b/megatron/model/word_embeddings.py @@ -50,9 +50,9 @@ def __init__( self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes - self.use_mup = neox_args.use_mup - self.mup_embedding_mult = neox_args.mup_embedding_mult - self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult + self.mup_embedding_multiplier = ( + float(neox_args.mup_embedding_multiplier) if neox_args.use_mup else 1.0 + ) # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding( @@ -142,7 +142,6 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): # OPT always adds 2 for some reason, according to the HF implementation position_ids = position_ids + self.opt_pos_emb_offset position_embeddings = self.position_embeddings(position_ids) - position_embeddings.mul_(self.mup_rp_embedding_mult) embeddings = words_embeddings + position_embeddings else: embeddings = words_embeddings @@ -154,10 +153,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): # Dropout. embeddings = self.embedding_dropout(embeddings) - - if self.use_mup: - with torch.no_grad(): - embeddings.mul_(self.mup_embedding_mult) + # Y_emb = m_emb * embed(X) + embeddings = torch.mul(embeddings, self.mup_embedding_multiplier) return embeddings diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 0d14806ac..8963e10e4 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -162,25 +162,6 @@ def __init__( self.weight, init_method, partition_dim=0, stride=1 ) - def mup_reinitialize_weights(self, neox_args): - if neox_args.use_cpu_initialization: - _initialize_affine_weight_cpu( - neox_args, - self.weight, - self.num_embeddings, - self.embedding_dim, - self.num_embeddings_per_partition, - 0, - partial(self.init_method, use_mup=True), - ) - else: - _initialize_affine_weight_gpu( - self.weight, - partial(self.init_method, use_mup=True), - partition_dim=0, - stride=1, - ) - def forward(self, input_): if self.model_parallel_size > 1: # Build the mask. 
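The new grouping in `get_params_for_weight_decay_optimization` above tags each parameter group with an `lr_adjust` flag so the scheduler can later scale the learning rate of matrix-like hidden weights by the width multiplier, while embedding weights, norms, and biases keep the base rate. A rough sketch of that grouping logic, using plain `torch.nn` modules as stand-ins (illustrative only, not the NeoX implementation; it omits the `_no_weight_decay` attribute handling and the weight-decay-off special case):

```python
import torch.nn as nn

def toy_mup_param_groups(model, weight_decay=0.1):
    groups = {
        ("adjust", "wd"): {"params": [], "lr_adjust": True, "weight_decay": weight_decay},
        ("adjust", "no_wd"): {"params": [], "lr_adjust": True, "weight_decay": 0.0},
        ("no_adjust", "wd"): {"params": [], "lr_adjust": False, "weight_decay": weight_decay},
        ("no_adjust", "no_wd"): {"params": [], "lr_adjust": False, "weight_decay": 0.0},
    }
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if isinstance(module, nn.LayerNorm):
                # norms never get weight decay and keep the base lr
                groups[("no_adjust", "no_wd")]["params"].append(param)
            elif isinstance(module, nn.Embedding):
                # embedding weights keep the base lr but still get weight decay
                groups[("no_adjust", "wd")]["params"].append(param)
            elif name.endswith("bias"):
                groups[("adjust", "no_wd")]["params"].append(param)
            else:
                # matrix-like hidden weights: lr will be divided by the width multiplier
                groups[("adjust", "wd")]["params"].append(param)
    return [g for g in groups.values() if g["params"]]
```

Keeping four groups rather than two mirrors the patch: the weight-decay split and the lr-adjust split need to stay independent.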
@@ -292,25 +273,6 @@ def __init__( self._k_len_cached = None self._rel_pos_bucket_cached = None - def mup_reinitialize_weights(self, neox_args): - if self.use_cpu_initialization: - _initialize_affine_weight_cpu( - neox_args, - self.weight, - self.num_buckets, - self.heads, - self.num_heads_per_partition, - partition_dim=1, - init_method=partial(self.init_method, use_mup=True), - ) - else: - _initialize_affine_weight_gpu( - self.weight, - partial(self.init_method, use_mup=True), - partition_dim=1, - stride=1, - ) - @staticmethod def get_heads_range(global_n_heads, rank, world_size): per_partition_n_heads = divide(global_n_heads, world_size) @@ -415,7 +377,6 @@ def __init__( skip_bias_add=False, MOE=False, MoE_mp_size=1, - mup_rescale_parameters=False, ): super(ColumnParallelLinear, self).__init__() @@ -429,8 +390,6 @@ def __init__( self.skip_bias_add = skip_bias_add self.init_method = init_method self.stride = stride - self.mup_rescale_parameters = mup_rescale_parameters - self.use_mup = neox_args.use_mup # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result @@ -492,56 +451,6 @@ def __init__( else: self.register_parameter("bias", None) - # Copied from Mup - def width_mult(self): - assert hasattr(self.weight, "infshape"), ( - "Please call set_base_shapes(...). If using torch.nn.DataParallel, " - "switch to distributed training with " - "torch.nn.parallel.DistributedDataParallel instead" - ) - return self.weight.infshape.width_mult() - - # Copied from Mup - def _rescale_parameters(self): - """Rescale parameters to convert SP initialization to μP initialization. - Warning: This method is NOT idempotent and should be called only once - unless you know what you are doing. - """ - if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: - raise RuntimeError( - "`_rescale_parameters` has been called once before already. " - "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" - "If you called `set_base_shapes` on a model loaded from a checkpoint, " - "or just want to re-set the base shapes of an existing model, " - "make sure to set the flag `rescale_params=False`.\n" - "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." - ) - if self.bias is not None: - self.bias.data *= self.width_mult() ** 0.5 - self.weight.data *= self.width_mult() ** 0.5 - self._has_rescaled_params = True - - def mup_reinitialize_weights(self, neox_args): - if neox_args.use_cpu_initialization: - self.master_weight = _initialize_affine_weight_cpu( - neox_args, - self.weight, - self.output_size, - self.input_size, - self.output_size_per_partition, - 0, - partial(self.init_method, use_mup=True), - stride=self.stride, - return_master_weight=keep_master_weight_for_test, - ) - else: - _initialize_affine_weight_gpu( - self.weight, - partial(self.init_method, use_mup=True), - partition_dim=0, - stride=self.stride, - ) - def set_parallel_output(self, value: bool): assert isinstance(value, bool) self.gather_output = ( @@ -549,8 +458,7 @@ def set_parallel_output(self, value: bool): ) # if gather_output is True, parallel output is False, so we set the opposite def forward(self, input_): - if self.use_mup and self.mup_rescale_parameters: - input_ /= self.width_mult() + # Set up backprop all-reduce. input_parallel = copy_to_model_parallel_region(input_) # Matrix multiply. 
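With the per-layer `width_mult` rescaling removed from the parallel linear layers above, the remaining output-side correction lives in `parallel_lm_logits` and `ParallelLinear.forward`: under muP the tied-embedding logits are divided by `mup_width_multiplier` and scaled by `mup_output_multiplier`. A toy, single-GPU sketch of that readout rule (illustrative only; the real code operates on the model-parallel partitions):

```python
import torch
import torch.nn.functional as F

def lm_logits_sketch(hidden, embedding_weight, use_mup=False,
                     mup_width_multiplier=1.0, mup_output_multiplier=1.0):
    # Tied readout: project hidden states onto the (transposed) embedding matrix.
    logits = F.linear(hidden, embedding_weight)
    if use_mup:
        # muP readout correction: 1/m_width keeps the logit scale width-independent,
        # m_output is a tunable output multiplier.
        logits = logits / mup_width_multiplier * mup_output_multiplier
    return logits

hidden = torch.randn(2, 8, 1024)   # [batch, seq, d_model]
w_emb = torch.randn(32000, 1024)   # [vocab, d_model]
logits = lm_logits_sketch(hidden, w_emb, use_mup=True, mup_width_multiplier=4.0)
```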
@@ -610,7 +518,6 @@ def __init__( MOE=False, MoE_mp_size=1, parallel_output=False, - mup_rescale_parameters=False, ): super(RowParallelLinear, self).__init__() @@ -626,8 +533,6 @@ def __init__( self.init_method = init_method self.stride = stride self.keep_master_weight_for_test = keep_master_weight_for_test - self.mup_rescale_parameters = mup_rescale_parameters - self.use_mup = neox_args.use_mup # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result @@ -683,63 +588,11 @@ def __init__( else: self.register_parameter("bias", None) - # Copied from Mup - def width_mult(self): - assert hasattr(self.weight, "infshape"), ( - "Please call set_base_shapes(...). If using torch.nn.DataParallel, " - "switch to distributed training with " - "torch.nn.parallel.DistributedDataParallel instead" - ) - return self.weight.infshape.width_mult() - - # Copied from Mup - def _rescale_parameters(self): - """Rescale parameters to convert SP initialization to μP initialization. - Warning: This method is NOT idempotent and should be called only once - unless you know what you are doing. - """ - if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: - raise RuntimeError( - "`_rescale_parameters` has been called once before already. " - "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" - "If you called `set_base_shapes` on a model loaded from a checkpoint, " - "or just want to re-set the base shapes of an existing model, " - "make sure to set the flag `rescale_params=False`.\n" - "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." - ) - if self.bias is not None: - self.bias.data *= self.width_mult() ** 0.5 - self.weight.data *= self.width_mult() ** 0.5 - self._has_rescaled_params = True - - def mup_reinitialize_weights(self, neox_args): - if neox_args.use_cpu_initialization: - self.master_weight = _initialize_affine_weight_cpu( - neox_args, - self.weight, - self.output_size, - self.input_size, - self.input_size_per_partition, - 1, - partial(self.init_method, use_mup=True), - stride=self.stride, - return_master_weight=self.keep_master_weight_for_test, - ) - else: - _initialize_affine_weight_gpu( - self.weight, - partial(self.init_method, use_mup=True), - partition_dim=1, - stride=self.stride, - ) - def set_parallel_output(self, parallel_output: bool): assert isinstance(parallel_output, bool) self.parallel_output = parallel_output def forward(self, input_): - if self.use_mup and self.mup_rescale_parameters: - input_ /= self.width_mult() # Set up backprop all-reduce. if self.input_is_parallel: input_parallel = input_ diff --git a/megatron/mup_substitute.py b/megatron/mup_substitute.py index e16a21589..9770c0765 100644 --- a/megatron/mup_substitute.py +++ b/megatron/mup_substitute.py @@ -2,6 +2,7 @@ Helper functions for performing coord check. 
""" import os +import gc from copy import copy from itertools import product @@ -9,64 +10,106 @@ import pandas as pd import torch import torch.nn.functional as F - -from mup import coord_check as mup_coord_check +import deepspeed +from megatron import print_rank_0 from megatron.training import train_step -def _get_coord_data( +def get_coord_data( neox_args, timers, - lr_scheduler, models, dataloader, - optcls, - nsteps=3, - dict_in_out=False, - flatten_input=False, - flatten_output=False, - output_name="loss", - lossfn="xent", - filter_module_by_name=None, - fix_data=True, - cuda=True, - nseeds=1, - output_fdict=None, - input_fdict=None, - param_fdict=None, - show_progress=True, - one_hot_target=False, + nsteps=10, + nseeds=2, ): - df = [] + df = { + "seed": [], + "step": [], + "word_embedding_act_abs_std": [], + "attn_output_act_abs_std": [], + "ffn_output_act_abs_std": [], + "output_logits_act_abs_std": [], + "width": [], + } + + df_mode = "mup" if neox_args.use_mup else "sp" + if neox_args.use_mup: + print_rank_0("muP Coord Check for mu Parameterization") + else: + print_rank_0("muP Coord Check for standard Parameterization") + + _df = None + df_path = os.path.join(neox_args.mup_save, f"df_{df_mode}.csv") + if (neox_args.mup_save is not None) and os.path.exists(df_path): + _df = pd.read_csv(df_path) - for i in range(nseeds): - torch.manual_seed(i) - for width, model in models.items(): - model = model() + with torch.no_grad(): + torch.cuda.empty_cache() + + for width, model_obj in models.items(): + for i in range(nseeds): + seed = (i + 1) * 100000 + torch.manual_seed(seed) + + model, optimizer, lr_scheduler = model_obj() model.train() - optimizer = optcls(model) + print_rank_0(f">>> muP Coord Check: Running Model with width: {width} on seed: {seed}") + print_rank_0(f">>> muP Coord Check: mup_width_multiplier set to {model.neox_args.mup_width_multiplier}") for step in range(nsteps + 1): + + word_embedding_act_abs_std_list = [] + attn_output_act_abs_std_list = [] + ffn_output_act_abs_std_list = [] + output_logits_act_abs_std_list = [] remove_hooks = [] - # add hooks + + def word_embedding_coord_check_hook(module, input, output): + with torch.no_grad(): + word_embedding_act_abs_std_list.append( + output.cpu().abs().std().item() + ) + + def attn_output_coord_check_hook(module, input, output): + with torch.no_grad(): + attn_output_act_abs_std_list.append( + output[0].cpu().abs().std().item() + ) + + def ffn_output_coord_check_hook(module, input, output): + with torch.no_grad(): + ffn_output_act_abs_std_list.append( + output[0].cpu().abs().std().item() + ) + + def output_logits_coord_check_hook(module, input, output): + with torch.no_grad(): + output_logits_act_abs_std_list.append( + output[0].cpu().abs().std().item() + ) + for name, module in model.named_modules(): - if filter_module_by_name and not filter_module_by_name(name): - continue - remove_hooks.append( - module.register_forward_hook( - mup_coord_check._record_coords( - df, - width, - name, - step + 1, - output_fdict=output_fdict, - input_fdict=input_fdict, - param_fdict=param_fdict, + if name.endswith(".word_embeddings"): + remove_hooks.append( + module.register_forward_hook( + word_embedding_coord_check_hook ) ) - ) + elif name.endswith(".attention.dense"): + remove_hooks.append( + module.register_forward_hook(attn_output_coord_check_hook) + ) + elif name.endswith(".mlp.dense_4h_to_h"): + remove_hooks.append( + module.register_forward_hook(ffn_output_coord_check_hook) + ) + elif name.endswith(".final_linear"): + remove_hooks.append( 
+ module.register_forward_hook(output_logits_coord_check_hook) + ) # train for a step - loss_dict, skipped_iter = train_step( + train_step( neox_args=neox_args, timers=timers, data_iterator=dataloader, @@ -75,138 +118,52 @@ def _get_coord_data( lr_scheduler=lr_scheduler, ) + word_embedding_act_abs_std = None + attn_output_act_abs_std = None + ffn_output_act_abs_std = None + output_logits_act_abs_std = None + # remove hooks for handle in remove_hooks: handle.remove() - - import gc - - del model + word_embedding_act_abs_std = np.mean(word_embedding_act_abs_std_list) + attn_output_act_abs_std = np.mean(attn_output_act_abs_std_list) + ffn_output_act_abs_std = np.mean(ffn_output_act_abs_std_list) + output_logits_act_abs_std = np.mean(output_logits_act_abs_std_list) + + df["seed"].append(i) + df["step"].append(step) + df["word_embedding_act_abs_std"].append(word_embedding_act_abs_std) + df["attn_output_act_abs_std"].append(attn_output_act_abs_std) + df["ffn_output_act_abs_std"].append(ffn_output_act_abs_std) + df["output_logits_act_abs_std"].append(output_logits_act_abs_std) + df["width"].append(width) + + def del_obj_attrs(obj): + attributes = [ + attr for attr in vars(obj) if not callable(getattr(obj, attr)) + ] + for attr in attributes: + try: + delattr(obj, attr) + except: + pass + + def unlink_hp_params(lp_param_list): + for lp in lp_param_list: + lp._hp_mapping = None + return + + for i, _ in enumerate(optimizer.optimizer.param_groups): + unlink_hp_params(optimizer.bit16_groups[i]) + del_obj_attrs(optimizer) + model.destroy() + del optimizer gc.collect() + torch.cuda.empty_cache() + deepspeed.runtime.utils.empty_cache() - return pd.DataFrame(df) - + temp_df = pd.DataFrame(df) + temp_df.to_csv(os.path.join(neox_args.mup_save, f"df_{df_mode}.csv"), index=False) -def get_coord_data( - neox_args, - timers, - lr_scheduler, - models, - dataloader, - optimizer="sgd", - lr=None, - mup=True, - filter_trainable_by_name=None, - **kwargs -): - """Get coord data for coord check. - Train the models in `models` with data from `dataloader` and optimizer - specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate - statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By - default, only `l1` is computed for output activations of each module. - This function wraps around `_get_coord_data`, with the main difference being - user can specify common optimizers via a more convenient interface. - Inputs: - models: - a dict of lazy models, where the keys are numbers indicating width. - Each entry of `models` is a function that instantiates a model given - nothing. - dataloader: - an iterator whose elements are either Huggingface style dicts, if - `dict_in_out` is True, or (input, label). If `fix_data` is True - (which is the default), then only the first element of `dataloader` - is used in a loop and the rest of `dataloder` is ignored. - optimizer: - a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`. - lr: - learning rate. By default is 0.1 for `'sgd'` and 1e-3 for others. - mup: - If True, then use the optimizer from `mup.optim`; otherwise, use the - one from `torch.optim`. - filter_trainable_by_name: - a function that returns a bool given module names (from - `model.named_modules()`), or None. If not None, then only modules - whose name yields True will be trained. - nsteps: - number of steps to train the model - dict_in_out: - whether the data loader contains Huggingface-style dict input and - output. 
Default: False - flatten_input: - if not `dict_in_out`, reshape the input to be - `input.view(input.shape[0], -1)`. Typically used for testing MLPs. - flatten_output: - if not `dict_in_out`, reshape the label to be `label.view(-1, - input.shape[-1])`. - output_name: - if `dict_in_out`, this is the key for the loss value if the output - is a dict. If the output is not a dict, then we assume the first - element of the output is the loss. - lossfn: - loss function to use if not `dict_in_out`. Can be either a string from - [`xent`, 'mse', 'nll', 'l1'] or a python `callable` such that - `lossfn(output, target)` returns the loss value. Examples of valid - `callable`s are `F.cross_entropy`, `F.mse_loss`, etc, where `F` is - `torch.nn.functional`. Default: 'xent' - filter_module_by_name: - a function that returns a bool given module names (from - `model.named_modules()`), or None. If not None, then only modules - whose name yields True will be recorded. - cuda: - whether to use cuda or not. Default: True - nseeds: - number of times to repeat the training, each with different seeds. - output_fdict, input_fdict, param_fdict: - function dicts to be used in `_record_coords`. By default, only `l1` - is computed for output activations of each module. - show_progress: - show progress using tqdm. Default: True - one_hot_target: - convert target label into a one-hot vector. This typically is only - used for `'mse'` or `'l1'` losses in classification tasks. - Default: False - Output: - a pandas DataFrame containing recorded results. The column names are - `'width', 'module', 't'` as well as names of statistics recorded, such - as `'l1'` (see `FDICT` for other premade statistics that can be - collected). - - Breaking Changes: - In v1.0.0, when `lossfn=='mse'`, the target is automatically converted - to a one hot vector before loss computation. Starting in v1.1.0, this - behavior is turned off, and the user needs to explicitly turn on this - behavior by setting `one_hot_target=True`. 
- """ - if lr is None: - lr = 0.1 if optimizer == "sgd" else 1e-3 - if mup: - from mup.optim import MuAdam as Adam - from mup.optim import MuAdamW as AdamW - from mup.optim import MuSGD as SGD - else: - from torch.optim import SGD, Adam, AdamW - - def get_trainable(model): - params = model.parameters() - if filter_trainable_by_name is not None: - params = [] - for name, p in model.named_parameters(): - if filter_trainable_by_name(name): - params.append(p) - return params - - if optimizer == "sgd": - optcls = lambda model: SGD(get_trainable(model), lr=lr) - elif optimizer == "adam": - optcls = lambda model: Adam(get_trainable(model), lr=lr) - elif optimizer == "adamw": - optcls = lambda model: AdamW(get_trainable(model), lr=lr) - elif optimizer is None: - raise ValueError("optimizer should be sgd|adam|adamw or a custom function") - - data = _get_coord_data( - neox_args, timers, lr_scheduler, models, dataloader, optcls, **kwargs - ) - data["optimizer"] = optimizer - data["lr"] = lr - return data + return pd.DataFrame(df) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index bf6e3f3e8..c201d7ec2 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1153,6 +1153,27 @@ def validate_values(self): if not self.deepspeed: return False + if self.use_mup: + if self.mup_d_model_base is None: + logging.info("mup_d_model_base is required when use_mup is True") + return False + + if self.mup_lr is not None: + self.lr = self.mup_lr + logging.info(f"Overriding lr with mup_lr: {self.mup_lr}") + + if self.mup_std is not None: + self.init_method_std = self.mup_std + logging.info(f"Overriding init_method_std with mup_std: {self.mup_std}") + + if self.mup_hidden_size is not None: + self.hidden_size = self.mup_hidden_size + logging.info(f"Overriding hidden_size with mup_hidden_size: {self.mup_hidden_size}") + + if self.mup_width_multiplier is None: + self.mup_width_multiplier = self.hidden_size / self.mup_d_model_base + logging.info(f"Overriding mup_width_multiplier with hidden_size/mup_d_model_base: {self.mup_width_multiplier}") + # learning rate if self.lr is None: error_message = self.__class__.__name__ + ".validate_values() lr is None" diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 16d6456b4..4c9e9fa39 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -109,6 +109,7 @@ class NeoXArgsModel(NeoXArgsTemplate): hidden_size: int = None """ Transformer hidden size. + When using muP, this is d_model """ intermediate_size: int = None @@ -309,6 +310,7 @@ class NeoXArgsModel(NeoXArgsTemplate): init_method_std: float = 0.02 """ Standard deviation of the zero mean normal distribution used for weight initialization. 
+ When using muP this is the base std """ apply_query_key_layer_scaling: bool = False @@ -522,6 +524,7 @@ class NeoXArgsOptimizer(NeoXArgsTemplate): lr: float = None """ Max Learning rate during training + When using muP, this is the base learning rate """ @@ -1140,7 +1143,30 @@ class NeoXArgsTraining(NeoXArgsTemplate): use_mup: bool = False """ - Whether to use Microsoft's Mup https://github.com/microsoft/mup + Whether to use muP + """ + + mup_save: str = None + """ + Path to save results when using muP + """ + + mup_lr: float = None + """ + An alias parameter for lr, + if not None will override lr + """ + + mup_std: float = None + """ + An alias parameter for init_method_std, + if not None will override init_method_std + """ + + mup_hidden_size: int = None + """ + An alias parameter for hidden_size, + if not None will override hidden_size """ coord_check: bool = False @@ -1148,45 +1174,45 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to generate a "coord check" plot to verify mup's implementation in neox """ - save_base_shapes: bool = False + coord_check_nsteps: int = 10 """ - Whether to save base shapes for mup. This will save the shapes to the path specified in base-shapes-file. + Number of steps to do for the coordinate check """ - base_shapes_file: str = None + coord_check_nseeds: int = 5 """ - Path to the base shapes to save to/load from + Number of repetition for each size in coordinate check """ - mup_init_scale: float = 1.0 + save_base_shapes: bool = False """ - Initialization scale: All the parameters are multiplied by this value + Whether to save base shapes for mup. This will save the shapes to the path specified in base-shapes-file. """ - mup_attn_temp: float = 1.0 + base_shapes_file: str = None """ - Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax + Path to the base shapes to save to/load from """ - mup_output_temp: float = 1.0 + mup_embedding_multiplier: float = 1.0 """ - Output temperature: Reciprocal of the multiplier applied to the input to softmax that - produces the distribution over output tokens. 
+ Embedding output multiplier """ - mup_embedding_mult: float = 1.0 + mup_output_multiplier: float = 1.0 """ - Scalar by which we multiply the output of the embedding layer + Output logits multiplier """ - mup_rp_embedding_mult: float = 1.0 + mup_width_multiplier: float = None """ - Scalar by which we multiply vectors representing relative position + Manually set the layer width multiplier (d_model/d_model,base) """ - mup_width_scale: int = 2 + mup_d_model_base: int = 256 """ - What to scale width by when creating the delta model for mup + d_model,base + Proxy (base) model's layer width """ diff --git a/megatron/training.py b/megatron/training.py index 4ce5994a5..e182ad05e 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -24,6 +24,7 @@ import math import sys +import gc from contextlib import nullcontext import torch @@ -60,112 +61,108 @@ from pickle import dump import os +import seaborn as sns +import matplotlib.pyplot as plt + + +def plot_coord_data(df, graph_name_prefix, use_mup=True, save_path=None): + def _plot_data(df, activation, graph_name_prefix): + df = df.groupby(["step", "width"]).mean().reset_index() + sns.color_palette("magma") + sns.lineplot( + data=df, + x="width", + y=activation, + hue="step", + errorbar=None, + style="step", + marker="o", + dashes=False, + legend="full", + ) + plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0) + plt.tight_layout(pad=3.0) + plt.xlabel("Width") + plt.ylabel("Activation with {}".format("muP" if use_mup else "SP")) + plt.title(f"{activation}") + + file_path = f"{graph_name_prefix}-{activation}.png" + if save_path is not None: + file_path = os.path.join(save_path, file_path) + + plt.savefig(file_path) + plt.close() + + return 0 + + activation_list = [ + "word_embedding_act_abs_std", + "attn_output_act_abs_std", + "ffn_output_act_abs_std", + "output_logits_act_abs_std", + ] + """If distributed is initialized print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + for activation in activation_list: + _plot_data(df, activation, graph_name_prefix) + else: + for activation in activation_list: + _plot_data(df, activation, graph_name_prefix) -def mup_weights_reinit(neox_args, model): - def has_method(o, name): - return callable(getattr(o, name, None)) - - for layer in model.modules(): - # This normally would happen in set_base_shapes if we actually were able to use the MuReadout class - if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters: - layer._rescale_parameters() - - if has_method(layer, "mup_reinitialize_weights"): - layer.mup_reinitialize_weights(neox_args) - - -def save_base_shapes(neox_args, base_shapes, use_cache): - - # Instantiation of the base model fails in the init function (init_functions.py) because we haven't called set_base_shapes on it at this point, so disable it temporarily here - neox_args.use_mup = False - - base_model = GPT2ModelPipe( - neox_args=neox_args, - num_tokentypes=0, - parallel_output=True, - topology=mpu.get_topology(), - use_cache=use_cache, - ) - - if not neox_args.is_pipe_parallel: - base_model = base_model.to_sequential() - - try: - import mup - except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") - raise Exception - - base_shapes = mup.get_shapes(base_model) - - del base_model - - old_hidden_size = neox_args.hidden_size - neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale - - delta_model = GPT2ModelPipe( - neox_args=neox_args, - 
num_tokentypes=0, - parallel_output=True, - topology=mpu.get_topology(), - use_cache=use_cache, - ) - - if not neox_args.is_pipe_parallel: - delta_model = delta_model.to_sequential() - - delta_shapes = mup.get_shapes(delta_model) - - # change back - neox_args.use_mup = True - neox_args.hidden_size = old_hidden_size - - save_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}" - print(f"saving base shapes at {save_shapes}") - mup.make_base_shapes(base_shapes, delta_shapes, savefile=save_shapes) - print(f"base shapes saved...exiting") - sys.exit(1) + return 0 -def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator): +def coord_check(neox_args, timers, train_data_iterator): from megatron.mup_substitute import get_coord_data - from mup.coord_check import plot_coord_data - def lazy_model(hidden_size): + if neox_args.mup_save is None: + print_rank_0("Must set mup_save") + sys.exit() + else: + os.makedirs(neox_args.mup_save, exist_ok=True) + + def lazy_model(hidden_size, attention_head, d_model_base=2**8): def gen(): - old_hidden_size = neox_args.hidden_size + neox_args.hidden_size = hidden_size + neox_args.num_attention_heads = attention_head + neox_args.mup_d_model_base = d_model_base + neox_args.mup_width_multiplier = hidden_size / d_model_base - model, optimizer, _ = setup_model_and_optimizer( + model, optimizer, lr_scheduler = setup_model_and_optimizer( neox_args=neox_args, use_cache=False ) - neox_args.hidden_size = old_hidden_size - - return model + return model, optimizer, lr_scheduler return gen models = {} + # Hidden size needs to be divisible by num attention heads #14 + for idx, hidden_size in enumerate([2**p for p in range(8, 11)]): + models[hidden_size] = lazy_model( + hidden_size, neox_args.num_attention_heads * (2**idx) + ) - # Hidden size needs to be divisible by num attention heads - for hidden_size in (neox_args.num_attention_heads * (2**p) for p in range(2, 9)): - models[hidden_size] = lazy_model(hidden_size) + df_mode = "mup" if neox_args.use_mup else "sp" - neox_args.use_mup = True - df_up = get_coord_data( - neox_args, timers, lr_scheduler, models, train_data_iterator, mup=True - ) - neox_args.use_mup = False - df_sp = get_coord_data( - neox_args, timers, lr_scheduler, models, train_data_iterator, mup=False + df = get_coord_data( + neox_args, + timers, + models, + train_data_iterator, + neox_args.coord_check_nsteps, + neox_args.coord_check_nseeds, ) - plot_coord_data(df_up, save_to=f"coord_check_up.{torch.distributed.get_rank()}.jpg") - plot_coord_data(df_sp, save_to=f"coord_check_sp.{torch.distributed.get_rank()}.jpg") + if neox_args.mup_save is not None: + plot_coord_data( + df, graph_name_prefix=f"coord_check_{df_mode}", use_mup=neox_args.use_mup, save_path=neox_args.mup_save + ) + print_rank_0("Saved coord check plots... exiting") - print_rank_0("Saved coord check plots... exiting") - sys.exit(1) + return 0 def pretrain(neox_args): @@ -190,6 +187,21 @@ def pretrain(neox_args): # Initialize and get arguments, timers, and Tensorboard writer. initialize_megatron(neox_args=neox_args) + if neox_args.coord_check: + print_rank_0("---- Do Coord Check ----") + # Data stuff + neox_args.iteration = 0 + timers("train/valid/test data iterators").start() + ( + train_data_iterator, + valid_data_iterator, + test_data_iterator, + ) = build_train_valid_test_data_iterators(neox_args=neox_args) + timers("train/valid/test data iterators").stop() + + coord_check(neox_args, timers, train_data_iterator) + sys.exit() + # Model, optimizer, and learning rate. 
timers("model and optimizer").start() model, optimizer, lr_scheduler = setup_model_and_optimizer( @@ -206,9 +218,6 @@ def pretrain(neox_args): ) = build_train_valid_test_data_iterators(neox_args=neox_args) timers("train/valid/test data iterators").stop() - if neox_args.use_mup and neox_args.coord_check: - mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator) - # Print setup timing. print_rank_0("done with setups ...") timers.log(["model and optimizer", "train/valid/test data iterators"]) @@ -423,11 +432,6 @@ def get_model(neox_args, use_cache=False): # Build model on cpu. print_rank_0("building GPT2 model ...") - # Temporarily disable mup so that the base model does not use the mup init functions before set_base_shapes is called below. - # If mup isn't being used anyways, this has no effect. - old_use_mup = neox_args.use_mup - neox_args.use_mup = False - with deepspeed.zero.Init( config_dict_or_path=neox_args.deepspeed_config ) if neox_args.zero_stage == 3 else nullcontext() as gs: @@ -463,25 +467,6 @@ def get_model(neox_args, use_cache=False): # Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training model = model.to_sequential() - neox_args.use_mup = old_use_mup - - if neox_args.use_mup: - try: - import mup - except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") - raise Exception - - base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}" - - if neox_args.save_base_shapes: - save_base_shapes(neox_args, base_shapes, use_cache) - - mup.set_base_shapes(model, base_shapes) - - # Call the mup replacement init functions on the model now that set_base_shapes has given each weight a .infshape attribute - mup_weights_reinit(neox_args, model) - if neox_args.deepspeed: # DeepSpeed handles CUDA, FP16, and DDP components. return model @@ -499,7 +484,13 @@ def get_optimizer(model, neox_args): f"ERROR: Optimizer is None. Either set the optimizer dict in your config (if training) or set no_load_optim in your config (if inference)" ) exit() - # Build parameter groups (weight decay and non-decay). + + if neox_args.lr is not None: + neox_args.optimizer["params"]["lr"] = neox_args.lr + + # Build parameter groups for parameters that + # are affected by weight decay and non-decay or + # have adjustable and non-adjustable learning rate. param_groups = get_params_for_weight_decay_optimization(model, neox_args) print_rank_0( f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}' @@ -581,49 +572,37 @@ def get_optimizer(model, neox_args): **neox_args.optimizer["params"], ) elif neox_args.optimizer_type.lower() == "adam": - # Use Adam - if neox_args.use_mup: + if neox_args.use_bnb_optimizer: try: - from mup import MuAdam + import bitsandbytes as bnb - adam_optimizer = MuAdam + adam_optimizer = bnb.optim.Adam8bit except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") + print( + "Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes." + ) raise Exception else: - if neox_args.use_bnb_optimizer: - try: - import bitsandbytes as bnb - - adam_optimizer = bnb.optim.Adam8bit - except ModuleNotFoundError: - print( - "Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes." 
- ) - raise Exception - else: - try: - # default to apex as it's slightly faster - from apex.optimizers import FusedAdam as Adam - except ImportError: - # if apex isn't installed, use deepspeed's FusedAdam - print( - "WARNING: APEX not installed - defaulting to deepspeed's fused adam" - ) - from deepspeed.ops.adam import FusedAdam as Adam - adam_optimizer = Adam + try: + # default to apex as it's slightly faster + from apex.optimizers import FusedAdam as Adam + except ImportError: + # if apex isn't installed, use deepspeed's FusedAdam + print( + "WARNING: APEX not installed - defaulting to deepspeed's fused adam" + ) + # from deepspeed.ops.adam import FusedAdam as Adam + from torch.optim import Adam + adam_optimizer = Adam optimizer = adam_optimizer( param_groups, weight_decay=neox_args.weight_decay, **neox_args.optimizer["params"], ) elif neox_args.optimizer_type.lower() == "sgd": - try: - from mup import MuSGD - except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") - raise Exception - optimizer = MuSGD( + from torch.optim import SGD + + optimizer = SGD( param_groups, weight_decay=neox_args.weight_decay, **neox_args.optimizer["params"], @@ -669,6 +648,7 @@ def get_learning_rate_scheduler(optimizer, neox_args): use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler, override_lr_scheduler=neox_args.override_lr_scheduler, use_mup=neox_args.use_mup, + mup_width_multiplier=neox_args.mup_width_multiplier, ) return lr_scheduler