diff --git a/configs/125M-moe.yml b/configs/125M-moe.yml new file mode 100644 index 000000000..3ad5c3647 --- /dev/null +++ b/configs/125M-moe.yml @@ -0,0 +1,98 @@ +# GPT-2 pretraining setup +{ + # Have 2 experts per layer (every 2 layers by default) + "num_experts": 2, + + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 12, + "hidden_size": 768, + "num_attention_heads": 12, + "seq_length": 2048, + "max_position_embeddings": 2048, + "norm": "layernorm", + "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0006, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00006, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0.0, + "attention_dropout": 0.0, + + # precision settings + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, + + # networking + "hostfile": "/mock_path" +} diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 9a4fb0f42..e881f2229 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -43,6 +43,7 @@ bias_dropout_add_fused_inference, ) from megatron.model.utils import configure_sparse_attention +from deepspeed.moe.layer import MoE # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) @@ -82,7 +83,13 @@ class ParallelMLP(nn.Module): """ def __init__( - self, neox_args, init_method, output_layer_init_method, parallel_output=False + self, + neox_args, + init_method, + output_layer_init_method, + parallel_output=False, + MOE=False, + MoE_mp_size=1, ): super().__init__() @@ -104,6 +111,8 @@ def __init__( gather_output=False, init_method=init_method, skip_bias_add=True, + MOE=MOE, + MoE_mp_size=MoE_mp_size, ) ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim # Project back to h. 
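A note on the MOE / MoE_mp_size arguments threaded through ParallelMLP above: in the hunks shown in this patch, the flag only changes which world size the column- and row-parallel linears use when partitioning their weights (see the megatron/mpu/layers.py hunks further down). A minimal sketch of that rule, with the exact-division check written out inline:

    # Sketch of the weight-partitioning rule once the MOE flag reaches
    # ColumnParallelLinear / RowParallelLinear.
    def partitioned_dim(full_size, model_parallel_world_size, MOE=False, MoE_mp_size=1):
        world_size = MoE_mp_size if MOE else model_parallel_world_size
        assert full_size % world_size == 0, "size must divide evenly across ranks"
        return full_size // world_size

    # Dense MLP under 4-way model parallelism: each rank holds 1/4 of the ffn dim.
    partitioned_dim(4 * 768, 4)                            # -> 768
    # Expert MLP built with MoE_mp_size=1: each expert keeps its full weight.
    partitioned_dim(4 * 768, 4, MOE=True, MoE_mp_size=1)   # -> 3072

In the ParallelTransformerLayer hunk below, MoE_mp_size is set to world_size // num_experts (falling back to 1 when there are more experts than ranks), so each expert's linears are sized as if sharded across that many ranks.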
@@ -113,8 +122,10 @@ def __init__( output_size=neox_args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - skip_bias_add=True, parallel_output=parallel_output, + skip_bias_add=True, + MOE=MOE, + MoE_mp_size=MoE_mp_size, ) def forward(self, hidden_states): @@ -817,6 +828,44 @@ def __init__( else: raise KeyError(neox_args.mlp_type) + self.num_experts = neox_args.num_experts + args = neox_args + if self.num_experts <= 1: + self.mlp = ParallelMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + ) + else: + from torch import distributed as dist + + if self.num_experts > dist.get_world_size(): + moe_mp_size = 1 + else: + moe_mp_size = dist.get_world_size() // self.num_experts + + self.mlp = MoE( + args.hidden_size, + ParallelMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.topk, + use_residual=(args.mlp_type == "residual"), + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + ) + self.layer_past = None # used to cache k/v pairs in inference def _get_bias_dropout(self): @@ -913,10 +962,19 @@ def forward(self, x, attention_mask, layer_past=None): ) # output = x + mlp(ln2(x)) - mlp_output, mlp_bias = self.mlp( - self.post_attention_layernorm(attention_output) + layernorm_output = self.post_attention_layernorm(attention_output) + moe_loss = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype ) + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + with torch.enable_grad(): if self.mlp_type == "llama": # No dropout either @@ -930,7 +988,7 @@ def forward(self, x, attention_mask, layer_past=None): prob=self.hidden_dropout, ) - return output + return output, moe_loss class ParallelTransformerLayerPipe(ParallelTransformerLayer): @@ -942,7 +1000,7 @@ def forward(self, args): ), "ParallelTransformerLayerPipe expects 2 arguments - hidden_states and attention_mask" hidden_states, attention_mask = args # we are returning just [hidden_states, mask] - return super().forward(hidden_states, attention_mask), attention_mask + return super().forward(hidden_states, attention_mask)[0], attention_mask class ParallelLinearPipe(ParallelLinear): diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 6beac5ca2..19f9ccf3d 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -28,8 +28,12 @@ def get_params_for_weight_decay_optimization(module, neox_args): """Divide params into with-weight-decay and without-weight-decay groups. Layernorms and biases will have no weight decay but the rest will. 
""" - weight_decay_params = {"params": []} - no_weight_decay_params = {"params": [], "weight_decay": 0.0} + weight_decay_params = {"params": [], "name": "weight_decay_params"} + no_weight_decay_params = { + "params": [], + "weight_decay": 0.0, + "name": "no_weight_decay_params", + } for module_ in module.modules(): if any( [ diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 611d2adbf..beb5dd8b1 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -31,6 +31,7 @@ from .initialize import get_pipe_parallel_group from .initialize import get_pipe_parallel_rank from .initialize import get_pipe_parallel_world_size +from .initialize import get_tensor_model_parallel_world_size from .initialize import get_io_parallel_group from .initialize import initialize_model_parallel from .initialize import model_parallel_is_initialized diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index 325e46ba4..8fe9d8bbd 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -266,6 +266,12 @@ def get_pipe_parallel_world_size(): return torch.distributed.get_world_size(group=get_pipe_parallel_group()) +# Needed for MOE. True tensor parallelism todo. +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size() + + def destroy_model_parallel(): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 92edbd6eb..8bf1dd224 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -413,6 +413,8 @@ def __init__( stride=1, keep_master_weight_for_test=False, skip_bias_add=False, + MOE=False, + MoE_mp_size=1, mup_rescale_parameters=False, ): super(ColumnParallelLinear, self).__init__() @@ -422,7 +424,7 @@ def __init__( self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. - world_size = get_model_parallel_world_size() + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add self.init_method = init_method @@ -605,6 +607,8 @@ def __init__( stride=1, keep_master_weight_for_test=False, skip_bias_add=False, + MOE=False, + MoE_mp_size=1, parallel_output=False, mup_rescale_parameters=False, ): @@ -615,7 +619,7 @@ def __init__( self.output_size = output_size self.input_is_parallel = input_is_parallel # Divide the weight matrix along the last dimension. 
- world_size = get_model_parallel_world_size() + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add self.parallel_output = parallel_output diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index d293f7b6a..0fdb39178 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1032,7 +1032,13 @@ def calculate_derived(self): # Update 'is pipe parallel' flag # if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs - self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1) + self.update_value( + "is_pipe_parallel", self.pipe_parallel_size > 1 and self.num_experts == 1 + ) + assert not ( + (self.is_pipe_parallel or self.pipe_parallel_size > 1) + and self.num_experts > 1 + ), "MoE not supported with pipeline parallelism" # Attention config if self.attention_config is None: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 8b76c8b32..67ff29380 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -83,6 +83,11 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ + expert_interval: int = 2 + """ + Have one MoE layer every expert_interval layers + """ + @dataclass class NeoXArgsModel(NeoXArgsTemplate): @@ -1154,3 +1159,48 @@ class NeoXArgsTextgen(NeoXArgsTemplate): """ Tasks to evaluate on using lm_eval_harness """ + + topk: int = 2 + """ + Activate top K experts in MoE + """ + + use_tutel: bool = False + """ + Use Tutel optimizations in MoE + """ + + num_experts: int = 1 + """ + Number of MoE experts + """ + + moe_train_capacity_factor: float = 1.0 + """ + The capacity of the expert at train time + """ + + moe_eval_capacity_factor: float = 1.0 + """ + The capacity of the expert at eval time + """ + + moe_min_capacity: int = 4 + """ + The minimum capacity per expert regardless of the capacity_factor + """ + + moe_token_dropping: bool = True + """ + Whether to drop tokens when exceeding capacity + """ + + create_moe_param_group: bool = True + """ + Whether to create a separate parameter group for MoE parameters + """ + + moe_expert_parallel_size: int = 1 + """ + Number of parallel experts in MoE + """ diff --git a/megatron/training.py b/megatron/training.py index e4e858574..ca7cc7b4b 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -478,6 +478,16 @@ def get_optimizer(model, neox_args): f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}' ) + if neox_args.create_moe_param_group: + from deepspeed.moe.utils import ( + is_moe_param, + split_params_into_different_moe_groups_for_optimizer, + ) + + param_groups = split_params_into_different_moe_groups_for_optimizer( + param_groups + ) + # Add model parallel attribute if it is not set. for param_group in param_groups: for param in param_group["params"]:
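On the optimizer hunk just above: with create_moe_param_group enabled, the groups produced by get_params_for_weight_decay_optimization are passed through DeepSpeed's split_params_into_different_moe_groups_for_optimizer so that expert parameters land in their own groups, which DeepSpeed can then reduce within the expert-parallel group rather than across all data-parallel ranks. The "name" keys added in megatron/model/utils.py are there because the DeepSpeed helper works off named groups (presumably deriving the expert groups' names from them). A schematic stand-in for the input shape, with the actual split left as a comment since it needs real expert-flagged parameters:

    import torch.nn as nn

    # Stand-in for the two groups built by get_params_for_weight_decay_optimization,
    # now carrying "name" keys.
    dense = nn.Linear(8, 8)
    param_groups = [
        {"name": "weight_decay_params", "params": [dense.weight]},
        {"name": "no_weight_decay_params", "params": [dense.bias], "weight_decay": 0.0},
    ]
    # training.py then does (only when neox_args.create_moe_param_group):
    #   from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
    #   param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups)
    # which moves any parameter DeepSpeed has flagged as an expert parameter
    # (is_moe_param) into additional groups and leaves dense parameters in place.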
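Stepping back to the transformer changes: ParallelTransformerLayer.forward now returns a (output, moe_loss) pair, and the MoE branch unpacks three values because deepspeed.moe.layer.MoE.forward returns the expert output, the auxiliary load-balancing loss, and the per-expert token counts. The pipeline wrapper simply drops the loss (hence the [0], alongside the new assertion that MoE is unsupported with pipeline parallelism), so the non-pipeline model is expected to collect the per-layer losses and add them to the LM loss. That wiring is outside the hunks shown here; the sketch below is illustrative only, with names chosen for the example rather than taken from the codebase:

    import torch

    def run_layers_collecting_moe_loss(layers, hidden_states, attention_mask):
        # Illustrative only: run the stack, keeping each layer's auxiliary loss.
        moe_losses = []
        for layer in layers:
            hidden_states, moe_loss = layer(hidden_states, attention_mask)
            moe_losses.append(moe_loss)
        return hidden_states, torch.stack(moe_losses).sum()

    # The caller would then fold the total into the LM loss with some coefficient;
    # that coefficient is a hypothetical knob, not something set anywhere in this diff:
    #   loss = lm_loss + moe_loss_coeff * total_moe_loss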
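Finally, a worked example for the new capacity settings (moe_train_capacity_factor, moe_eval_capacity_factor, moe_min_capacity, moe_token_dropping). The conventional GShard-style rule is sketched below; DeepSpeed's exact internals may differ slightly (for instance, top-2 routing typically doubles the effective factor), so treat the numbers as a sizing aid rather than a guarantee:

    import math

    def expert_capacity(num_tokens, num_experts, capacity_factor, min_capacity):
        # Each expert accepts at most this many tokens per forward pass; with
        # moe_token_dropping=True, overflow tokens are dropped from the expert
        # computation and effectively pass through on the block's residual path.
        return max(min_capacity, math.ceil(num_tokens / num_experts * capacity_factor))

    # With 125M-moe.yml: seq_length=2048, train_micro_batch_size_per_gpu=4,
    # num_experts=2, moe_train_capacity_factor=1.0, moe_min_capacity=4.
    expert_capacity(2048 * 4, 2, 1.0, 4)   # -> 4096 tokens per expert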