From c509f6afd6c7cfb3e4693bd26943cf7d83b870b1 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sat, 25 Nov 2023 21:07:41 -0800 Subject: [PATCH 1/2] Initial SP support --- configs/125M.yml | 1 + eval_tasks/eval_adapter.py | 3 +- megatron/checkpointing.py | 2 +- megatron/data/data_utils.py | 2 +- megatron/initialize.py | 54 ++-- megatron/model/transformer.py | 15 +- megatron/mpu/__init__.py | 2 +- megatron/mpu/initialize.py | 355 +++++++++++++++++++-------- megatron/neox_arguments/arguments.py | 10 +- megatron/neox_arguments/neox_args.py | 13 +- megatron/text_generation_utils.py | 38 +-- megatron/utils.py | 10 +- tests/model/test_model_generation.py | 4 +- 13 files changed, 349 insertions(+), 160 deletions(-) diff --git a/configs/125M.yml b/configs/125M.yml index 15a4b3b01..794c236b2 100644 --- a/configs/125M.yml +++ b/configs/125M.yml @@ -4,6 +4,7 @@ # across the node boundaries ) "pipe_parallel_size": 1, "model_parallel_size": 1, + "sequence_parallel_size": 1, # model settings "num_layers": 12, diff --git a/eval_tasks/eval_adapter.py b/eval_tasks/eval_adapter.py index e0a32797d..e824a0beb 100644 --- a/eval_tasks/eval_adapter.py +++ b/eval_tasks/eval_adapter.py @@ -72,13 +72,14 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None): self.is_model_parallel = neox_args.model_parallel_size > 1 self.is_pipe_parallel = self.model.is_pipe_parallel self.is_data_parallel = self.model.is_data_parallel + self.is_sequence_parallel = self.model.is_sequence_parallel self.is_last_stage = ( True if not self.is_pipe_parallel else model.is_last_stage() ) # only the last stage of the pipeline model will receive the logits self.dp_world_size = mpu.get_data_parallel_world_size() self.dp_rank = mpu.get_data_parallel_rank() self.dp_group = mpu.get_data_parallel_group() - self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0 + self.is_mp_rank_0 = (not self.is_sequence_parallel and mpu.get_tensor_parallel_rank() == 0) or (self.is_sequence_parallel and mpu.get_sequence_parallel_rank() == 0) self._batch_size = batch_size or ( neox_args.batch_size * self.dp_world_size diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 8bcc01f3b..d310cbb71 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -141,7 +141,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None checkpoints_path, directory, "mp_rank_{:02d}".format( - mpu.get_model_parallel_rank() if mp_rank is None else mp_rank + mpu.get_tensor_parallel_rank() if mp_rank is None else mp_rank ), "model_optim_rng.pt", ) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 513dd0e21..40900a246 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -312,7 +312,7 @@ def build_train_valid_test_data_iterators(neox_args): pipe_load = True # Data loader only on rank 0 of each model parallel group. - if mpu.get_model_parallel_rank() == 0 and pipe_load: + if ((not neox_args.is_sequence_parallel and mpu.get_model_parallel_rank() == 0) or (neox_args.is_sequence_parallel and mpu.get_sequence_parallel_rank() == 0)) and pipe_load: # Number of train/valid/test samples. train_iters = neox_args.train_iters eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters diff --git a/megatron/initialize.py b/megatron/initialize.py index bc4032649..6c03d59cc 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -154,29 +154,39 @@ def _initialize_distributed(neox_args): ) # Setup 3D topology. 
- pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1 - mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1 - assert ( - neox_args.world_size % (pp * mp) == 0 - ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}" - dp = neox_args.world_size // (pp * mp) - - from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology - - # this does pipe on the most outside, then data, then model. - # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order. - topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp) - - # Offset base seeds for the interior pipeline stages. - # TODO: adjust last stage too once IO is improved. - stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe - if 0 < stage_id < topo.get_dim("pipe") - 1: - offset = neox_args.seed + 1138 - neox_args.seed = offset + (stage_id * mp) + #pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1 + #mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1 + #assert ( + # neox_args.world_size % (pp * mp) == 0 + #), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}" + #sp = neox_args.sequence_parallel_size if neox_args.sequence_parallel_size >= 1 else 1 + #assert ( + # (sp > 1 and (pp == 1 and mp == 1)) or (sp == 1) + #), f"sp={sp} cannot be used along with pp>1 or mp>1" + #if sp > 1: + # dp = neox_args.world_size // (sp) + #else: + # dp = neox_args.world_size // (pp * mp) +# + #from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology, ProcessTopology +# + #if sp > 1: + # topo = ProcessTopology(axes=['data', 'sequence'], dims=[dp, sp]) + #else: + # # this does pipe on the most outside, then data, then model. + # # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order. + # topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp) +# + # # Offset base seeds for the interior pipeline stages. + # # TODO: adjust last stage too once IO is improved. + # stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe + # if 0 < stage_id < topo.get_dim("pipe") - 1: + # offset = neox_args.seed + 1138 + # neox_args.seed = offset + (stage_id * mp) # Set the model-parallel / data-parallel communicators. if device_count > 0: - if mpu.model_parallel_is_initialized(): + if mpu.tensor_parallel_is_initialized() or mpu.sequence_parallel_is_initialized: print( "_initialize_distributed() model parallel is already initialized", flush=True, @@ -184,7 +194,9 @@ def _initialize_distributed(neox_args): else: mpu.initialize_model_parallel( neox_args.model_parallel_size, - topology=topo, + neox_args.sequence_parallel_size, + neox_args.pipeline_parallel_size, + #topology=topo, fp32_allreduce=neox_args.fp32_allreduce, ) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 63f4122e2..a1ab46036 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -275,7 +275,7 @@ def __init__( self.attention_softmax_in_fp32 = True self.layer_number = layer_number # Per attention head and per partition values. 
- world_size = mpu.get_model_parallel_world_size() + world_size = mpu.get_tensor_parallel_world_size() self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) self.hidden_size_per_attention_head = mpu.divide( neox_args.hidden_size, neox_args.num_attention_heads @@ -310,7 +310,7 @@ def __init__( self.alibi_embed = AliBi( neox_args.num_attention_heads, neox_args.model_parallel_size, - mpu.get_model_parallel_rank(), + mpu.get_tensor_parallel_rank(), ) # TODO: this arg shouldn't need to be passed in - get from neox_args @@ -338,6 +338,7 @@ def __init__( self.attention_type = neox_args.attention_config[layer_number] self.use_flash_attention = self.attention_type == "flash" + self.use_ds_ulysses_attention = self.attention_type == "ulysses" self.sparse = self.attention_type not in ("global", "flash") if self.sparse: self.sparse_attn = configure_sparse_attention( @@ -360,6 +361,13 @@ def __init__( self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton self.flash_qkv_fn = flash_attn_varlen_qkvpacked_func self.flash_kv_fn = flash_attn_varlen_kvpacked_func + elif self.use_ds_ulysses_attention: + print('USING ULYSSES') + try: + from deepspeed.sequence.layer import DistributedAttention + self.ds_ulysses_attention_fn = DistributedAttention(self.attention, mpu.get_sequence_parallel_group()) + except ImportError as e: + print(f'Error. You passed a gpt-neox ulysses config but DeepSpeed ulysses could not be imported with the error: {e}') else: self.scale_mask_softmax = FusedScaleMaskSoftmax( input_in_fp16=self.fp16, @@ -598,6 +606,7 @@ def flash_attention(self, query_layer, key_layer, value_layer): return matmul_result + def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): # TODO: sparse attn dropout? # TODO: pad to block size @@ -688,6 +697,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None): if self.use_flash_attention: context_layer = self.flash_attention(query_layer, key_layer, value_layer) + elif self.use_ds_ulysses_attention: + context_layer = self.ds_ulysses_attention_fn(query_layer, key_layer, value_layer, attention_mask) elif not self.sparse: context_layer = self.attention( query_layer, key_layer, value_layer, layer_past, attention_mask diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 611d2adbf..06450d172 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -18,7 +18,7 @@ from .data import broadcast_data -from .initialize import is_unitialized +from .initialize import is_uninitialized from .initialize import destroy_model_parallel from .initialize import get_data_parallel_group from .initialize import get_data_parallel_rank diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index 325e46ba4..510168c2e 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -23,11 +23,13 @@ from .utils import ensure_divisibility # Model parallel group that the current rank belongs to. -_MODEL_PARALLEL_GROUP = None +_TENSOR_PARALLEL_GROUP = None # Data parallel group that the current rank belongs to. _DATA_PARALLEL_GROUP = None # Pipeline parallel group that the current rank belongs to. _PIPE_PARALLEL_GROUP = None +# Sequence parallel group that the current rank belongs to. +_SEQUENCE_PARALLEL_GROUP = None # A group used to sync during the IO process. 
Usually this is data_parallel_group(), # but with pipeline parallelism it must also involve the last stage (which is not in the @@ -35,27 +37,35 @@ _IO_PARALLEL_GROUP = None # These values enable us to change the mpu sizes on the fly. -_MPU_WORLD_SIZE = None -_MPU_RANK = None +_MPU_OR_SPU_WORLD_SIZE = None +_MPU_OR_SPU_RANK = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_OR_SPU_WORLD_SIZE = None +_MPU_OR_SPU_RANK = None # Used to query 3D topology -_MPU_TOPOLOGY = None +_MPU_OR_SPU_TOPOLOGY = None # Get fp32_allreduce flag _FP32_ALLREDUCE = None +# Are we using deepspeed sequence parallelism +_IS_SEQUENCE_PARALLEL = None -def is_unitialized(): +def is_uninitialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is None -def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False): +def initialize_model_parallel(tensor_model_parallel_size, sequence_model_parallel_size, pipeline_model_parallel_size, neox_args, fp32_allreduce=False): """ Initialize model data parallel groups. Arguments: - model_parallel_size: number of GPUs used to parallelize model. + tensor_model_parallel_size: number of GPUs used to parallelize model. + pipeline_model_parallel_size: number of GPUs used to parallelize model. + sequence_model_parallel_size: number of GPUs used to parallelize model. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model. The present function will @@ -69,19 +79,69 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. """ - if torch.distributed.get_rank() == 0: - print("> initializing model parallel with size {}".format(model_parallel_size)) + # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() - if world_size < model_parallel_size: - raise ValueError("world size cannot be smaller than model parallel size") - ensure_divisibility(world_size, model_parallel_size) + global _IS_SEQUENCE_PARALLEL + _IS_SEQUENCE_PARALLEL = sequence_model_parallel_size > 1 + + if _IS_SEQUENCE_PARALLEL: + assert tensor_model_parallel_size == 1 and pipeline_model_parallel_size == 1, \ + 'DeepSpeed\'s sequence parallel does not work with tensor parallel or pipeline parallel' + + if torch.distributed.get_rank() == 0: + if _IS_SEQUENCE_PARALLEL: + print("> initializing sequence model parallel with size {}".format(sequence_model_parallel_size)) + else: + print("> initializing tensor model parallel with size {}, and pipeline model parallel with size {}".format(tensor_model_parallel_size, pipeline_model_parallel_size)) + + if _IS_SEQUENCE_PARALLEL: + # Ensure none of the parallel sizes are too large + if world_size < sequence_model_parallel_size: + raise ValueError("world size cannot be smaller than sequence model parallel size") + # Ensure each axis is divisible by world size + ensure_divisibility(world_size, sequence_model_parallel_size) + data_parallel_size = world_size // sequence_model_parallel_size + else: + # Ensure none of the parallel sizes are too large + if world_size < tensor_model_parallel_size * pipeline_model_parallel_size: + raise ValueError("world size cannot be smaller than tensor_model_parallel_size * pipeline_model_parallel_size") + # Ensure each axis is divisible by world size + ensure_divisibility(world_size, tensor_model_parallel_size) + ensure_divisibility(world_size, pipeline_model_parallel_size) + data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) + + + # Set up the topology + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology, ProcessTopology + + if _IS_SEQUENCE_PARALLEL: + topology = ProcessTopology(axes=['data', 'sequence'], dims=[data_parallel_size, sequence_model_parallel_size]) + else: + # this does pipe on the most outside, then data, then model. + # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order. + topology = PipeModelDataParallelTopology(num_pp=pipeline_model_parallel_size, num_mp=tensor_model_parallel_size, num_dp=data_parallel_size) + + # Offset base seeds for the interior pipeline stages. + # TODO: adjust last stage too once IO is improved. + stage_id = topology.get_coord(rank=torch.distributed.get_rank()).pipe + if 0 < stage_id < topology.get_dim("pipe") - 1: + offset = neox_args.seed + 1138 + neox_args.seed = offset + (stage_id * mp) + + #data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size * sequence_parallel_size) + #sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size + + #assert (sequence_parallel_size > 1 and tensor_model_parallel_size == 1) or (sequence_parallel_size == 1), "sequence parallelism not yet supported with pipeline or tensor parallelism" + #num_sequence_parallel_groups: int = world_size // sequence_parallel_size + #num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size + rank = torch.distributed.get_rank() - global _MPU_TOPOLOGY + global _MPU_OR_SPU_TOPOLOGY if topology: - _MPU_TOPOLOGY = topology + _MPU_OR_SPU_TOPOLOGY = topology # Build the data parallel groups. 
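As a sanity check on the layout that the `ProcessTopology(axes=['data', 'sequence'], ...)` constructed above produces, the short standalone sketch below (not part of the patch; it only assumes 8 ranks with `sequence_model_parallel_size=2`, hence `data_parallel_size=4`) prints the communication lists that the group-building code below turns into process groups:

    from deepspeed.runtime.pipe.topology import ProcessTopology

    # 8 ranks, data axis outermost / sequence axis innermost, matching the call above.
    topo = ProcessTopology(axes=["data", "sequence"], dims=[4, 2])

    # Adjacent ranks share a sequence-parallel group (where Ulysses' all-to-all runs):
    print(topo.get_axis_comm_lists("sequence"))  # expected: [[0, 1], [2, 3], [4, 5], [6, 7]]
    # Data-parallel groups stride across the sequence groups:
    print(topo.get_axis_comm_lists("data"))      # expected: [[0, 2, 4, 6], [1, 3, 5, 7]]

Because the sequence axis is innermost, each sequence-parallel group consists of consecutive global ranks, which on typical launchers keeps the Ulysses all-to-all traffic inside a single node whenever the group fits on one node.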
global _DATA_PARALLEL_GROUP @@ -90,29 +150,36 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce for dp_group in topology.get_axis_comm_lists("data"): group = torch.distributed.new_group(ranks=dp_group) if rank == 0: - print(f"MPU DP:", dp_group) + print(f"DP Group:", dp_group) if rank in dp_group: _DATA_PARALLEL_GROUP = group else: - for i in range(model_parallel_size): - ranks = range(i, world_size, model_parallel_size) - group = torch.distributed.new_group(ranks) - if i == (rank % model_parallel_size): - _DATA_PARALLEL_GROUP = group + if _IS_SEQUENCE_PARALLEL: + for i in range(sequence_model_parallel_size): + ranks = range(i, world_size, sequence_model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % sequence_model_parallel_size): + _DATA_PARALLEL_GROUP = group + else: + for i in range(tensor_model_parallel_size): + ranks = range(i, world_size, tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % tensor_model_parallel_size): + _DATA_PARALLEL_GROUP = group # Build pipeline parallel group - if topology is not None: + if topology is not None and not _IS_SEQUENCE_PARALLEL: global _PIPE_PARALLEL_GROUP for pp_group in topology.get_axis_comm_lists("pipe"): group = torch.distributed.new_group(ranks=pp_group) if rank == 0: - print(f"MPU PP:", pp_group) + print(f"PP Group:", pp_group) if rank in pp_group: _PIPE_PARALLEL_GROUP = group - # Build IO group + # Build IO group for PP global _IO_PARALLEL_GROUP - if topology and topology.get_dim("pipe") > 1: + if topology and not _IS_SEQUENCE_PARALLEL and topology.get_dim("pipe") > 1: io_stages = [0, topology.get_dim("pipe") - 1] io_group = [] for stage in io_stages: @@ -126,57 +193,97 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce _IO_PARALLEL_GROUP = get_data_parallel_group() # Build the model parallel groups. - global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + global _SEQUENCE_PARALLEL_GROUP + assert _SEQUENCE_PARALLEL_GROUP is None, "sequence model parallel group is already initialized" + global _TENSOR_PARALLEL_GROUP + assert _TENSOR_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" if topology: - # Short circuit case without model parallelism. + # Short circuit case without tensor/sequence parallelism. # TODO: it would be nice to avoid this branching case? 
- if model_parallel_size == 1: - for group_rank in range(world_size): - group = torch.distributed.new_group(ranks=[group_rank]) - if rank == 0: - print(f"MPU MP:", [group_rank]) - if rank == group_rank: - _MODEL_PARALLEL_GROUP = group + if _IS_SEQUENCE_PARALLEL: + if sequence_model_parallel_size == 1: + for group_rank in range(world_size): + group = torch.distributed.new_group(ranks=[group_rank]) + if rank == 0: + print(f"SP Group:", [group_rank]) + if rank == group_rank: + _SEQUENCE_PARALLEL_GROUP = group + return + else: + if tensor_model_parallel_size == 1: + for group_rank in range(world_size): + group = torch.distributed.new_group(ranks=[group_rank]) + if rank == 0: + print(f"TP Group:", [group_rank]) + if rank == group_rank: + _TENSOR_PARALLEL_GROUP = group return - for mp_group in topology.get_axis_comm_lists("model"): - group = torch.distributed.new_group(ranks=mp_group) - if rank == 0: - print(f"MPU MP:", mp_group) - if rank in mp_group: - _MODEL_PARALLEL_GROUP = group - + if _IS_SEQUENCE_PARALLEL: + for sp_group in topology.get_axis_comm_lists("sequence"): + group = torch.distributed.new_group(ranks=sp_group) + if rank == 0: + print(f"SP Group:", sp_group) + if rank in sp_group: + _SEQUENCE_PARALLEL_GROUP = group + else: + for tp_group in topology.get_axis_comm_lists("model"): + group = torch.distributed.new_group(ranks=tp_group) + if rank == 0: + print(f"TP Group:", tp_group) + if rank in tp_group: + _TENSOR_PARALLEL_GROUP = group else: - for i in range(world_size // model_parallel_size): - ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) - group = torch.distributed.new_group(ranks) - if i == (rank // model_parallel_size): - _MODEL_PARALLEL_GROUP = group + if _IS_SEQUENCE_PARALLEL: + for i in range(world_size // sequence_model_parallel_size): + ranks = range(i * sequence_model_parallel_size, (i + 1) * sequence_model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // sequence_model_parallel_size): + _SEQUENCE_PARALLEL_GROUP = group + else: + for i in range(world_size // tensor_model_parallel_size): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // tensor_model_parallel_size): + _TENSOR_PARALLEL_GROUP = group global _FP32_ALLREDUCE assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized" _FP32_ALLREDUCE = fp32_allreduce +# Check if initialized +def tensor_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _TENSOR_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True -def model_parallel_is_initialized(): +def sequence_parallel_is_initialized(): """Check if model and data parallel groups are initialized.""" - if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + if _SEQUENCE_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: return False return True -def get_model_parallel_group(): - """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" - return _MODEL_PARALLEL_GROUP +# Get the parallel group +def get_sequence_model_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + return _SEQUENCE_PARALLEL_GROUP +def get_tensor_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _TENSOR_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return 
_TENSOR_PARALLEL_GROUP def get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _DATA_PARALLEL_GROUP +def get_pipe_model_parallel_group(): + """Get the pipe parallel group the caller rank belongs to.""" + assert _PIPE_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _PIPE_PARALLEL_GROUP def get_io_parallel_group(): """Get the IO parallel group the caller rank belongs to.""" @@ -184,104 +291,140 @@ def get_io_parallel_group(): return _IO_PARALLEL_GROUP -def set_model_parallel_world_size(world_size): +# Set the parallel world size +def set_tensor_model_parallel_world_size(world_size): """Set the model parallel size""" - global _MPU_WORLD_SIZE - _MPU_WORLD_SIZE = world_size + global _MPU_OR_SPU_WORLD_SIZE + _MPU_OR_SPU_WORLD_SIZE = world_size +def set_sequence_model_parallel_world_size(world_size): + """Set the model parallel size""" + global _MPU_OR_SPU_WORLD_SIZE + _MPU_OR_SPU_WORLD_SIZE = world_size + + +# Get the parallel world size +def get_tensor_model_parallel_world_size(): + """Return world size for the model parallel group.""" + global _MPU_OR_SPU_WORLD_SIZE + if _MPU_OR_SPU_WORLD_SIZE is not None: + return _MPU_OR_SPU_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) -def get_model_parallel_world_size(): +def get_sequence_model_parallel_world_size(): """Return world size for the model parallel group.""" - global _MPU_WORLD_SIZE - if _MPU_WORLD_SIZE is not None: - return _MPU_WORLD_SIZE - return torch.distributed.get_world_size(group=get_model_parallel_group()) + global _MPU_OR_SPU_WORLD_SIZE + if _MPU_OR_SPU_WORLD_SIZE is not None: + return _MPU_OR_SPU_WORLD_SIZE + return torch.distributed.get_world_size(group=get_sequence_model_parallel_group()) -def set_model_parallel_rank(rank): - """Set model parallel rank.""" - global _MPU_RANK - _MPU_RANK = rank +# Set the parallel rank +def set_tensor_model_parallel_rank(rank): + """Set tensor parallel rank.""" + global _MPU_OR_SPU_RANK + _MPU_OR_SPU_RANK = rank +def set_sequence_model_parallel_rank(rank): + """Set sequence parallel rank.""" + global _MPU_OR_SPU_RANK + _MPU_OR_SPU_RANK = rank -def get_model_parallel_rank(): + +# Get the parallel rank +def get_tensor_model_parallel_rank(): + """Return my rank for the model parallel group.""" + global _MPU_OR_SPU_RANK + if _MPU_OR_SPU_RANK is not None: + return _MPU_OR_SPU_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + +def get_sequence_model_parallel_rank(): """Return my rank for the model parallel group.""" - global _MPU_RANK - if _MPU_RANK is not None: - return _MPU_RANK - return torch.distributed.get_rank(group=get_model_parallel_group()) + global _MPU_OR_SPU_RANK + if _MPU_OR_SPU_RANK is not None: + return _MPU_OR_SPU_RANK + return torch.distributed.get_rank(group=get_sequence_model_parallel_group()) +def get_pipe_model_parallel_rank(): + """Return my rank for the pipe parallel group.""" + return torch.distributed.get_rank(group=get_pipe_model_parallel_group()) -def get_model_parallel_src_rank(): + +# Get the src rank +def get_tensor_model_parallel_src_rank(): """Calculate the global rank corresponding to a local rank zero in the model parallel group.""" global_rank = torch.distributed.get_rank() - local_world_size = get_model_parallel_world_size() + local_world_size = get_tensor_model_parallel_world_size() return (global_rank // local_world_size) * 
local_world_size +def get_sequence_model_parallel_src_rank(): + """Calculate the global rank corresponding to a local rank zero + in the model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_sequence_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size def get_data_parallel_src_rank(): """Calculate the global rank corresponding to a local rank zero in the data parallel group.""" global_rank = torch.distributed.get_rank() topo = get_topology() - if topo is None: - # we are just using model parallel - return global_rank % get_model_parallel_world_size() + if _IS_SEQUENCE_PARALLEL: + # we are just using tensor parallel + return global_rank % get_tensor_model_parallel_world_size() else: - # We are using pipeline parallel - d = topo.get_axis_comm_lists("data") - for l in d: - if global_rank in l: - return l[0] - - + if topo is None: + # we are just using tensor parallel + return global_rank % get_tensor_model_parallel_world_size() + else: + # We are using pipeline parallel + d = topo.get_axis_comm_lists("data") + for l in d: + if global_rank in l: + return l[0] + +# Get the world size def get_data_parallel_world_size(): """Return world size for the data parallel group.""" return torch.distributed.get_world_size(group=get_data_parallel_group()) +def get_pipe_parallel_world_size(): + """Return world size for the pipe parallel group.""" + return torch.distributed.get_world_size(group=get_pipe_model_parallel_group()) def get_data_parallel_rank(): """Return my rank for the data parallel group.""" return torch.distributed.get_rank(group=get_data_parallel_group()) +def is_sequence_parallel(): + """Return whether sequence parallelism is used""" + return _IS_SEQUENCE_PARALLEL +# Get topology def get_topology(): - return _MPU_TOPOLOGY - - -def get_pipe_parallel_group(): - """Get the pipe parallel group the caller rank belongs to.""" - assert _PIPE_PARALLEL_GROUP is not None, "data parallel group is not initialized" - return _PIPE_PARALLEL_GROUP - - -def get_pipe_parallel_rank(): - """Return my rank for the pipe parallel group.""" - return torch.distributed.get_rank(group=get_pipe_parallel_group()) - - -def get_pipe_parallel_world_size(): - """Return world size for the pipe parallel group.""" - return torch.distributed.get_world_size(group=get_pipe_parallel_group()) + return _MPU_OR_SPU_TOPOLOGY def destroy_model_parallel(): """Set the groups to none.""" - global _MODEL_PARALLEL_GROUP - _MODEL_PARALLEL_GROUP = None + global _TENSOR_PARALLEL_GROUP + _TENSOR_PARALLEL_GROUP = None global _DATA_PARALLEL_GROUP _DATA_PARALLEL_GROUP = None global _PIPE_PARALLEL_GROUP _PIPE_PARALLEL_GROUP = None + global _SEQUENCE_PARALLEL_GROUP + _SEQUENCE_PARALLEL_GROUP = None global _IO_PARALLEL_GROUP _IO_PARALLEL_GROUP = None - global _MPU_WORLD_SIZE - global _MPU_RANK - _MPU_WORLD_SIZE = None - _MPU_RANK = None - global _MPU_TOPOLOGY - _MPU_TOPOLOGY = None + global _MPU_OR_SPU_WORLD_SIZE + global _MPU_OR_SPU_RANK + _MPU_OR_SPU_WORLD_SIZE = None + _MPU_OR_SPU_RANK = None + global _MPU_OR_SPU_TOPOLOGY + _MPU_OR_SPU_TOPOLOGY = None global _FP32_ALLREDUCE _FP32_ALLREDUCE = None diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 09e45f10c..c19015289 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -742,8 +742,8 @@ def configure_distributed_args(self): if self.rank == 0: print( self.__class__.__name__ - + ".configure_distributed_args() using world size: {} and 
model-parallel size: {} ".format( - self.world_size, self.model_parallel_size + + ".configure_distributed_args() using world size: {}, pipe-parallel size: {}, sequence-parallel size: {}, and model-parallel size: {} ".format( + self.world_size, self.pipe_parallel_size, self.sequence_parallel_size, self.model_parallel_size ), flush=True, ) @@ -847,6 +847,9 @@ def calculate_derived(self): pp_size = pp_size if pp_size >= 1 else 1 mp_size = self.model_parallel_size mp_size = mp_size if mp_size >= 1 else 1 + sp_size = self.sequence_parallel_size + sp_size = sp_size if sp_size >= 1 else 1 + self.update_value("sequence_parallel_size", sp_size) self.update_value("model_parallel_size", mp_size) # pp_size and mp_size are only used here to compute dp world size and nowhere else. @@ -1023,6 +1026,9 @@ def calculate_derived(self): # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1) + # Update 'is sequence parallel' flag + self.update_value("is_sequence_parallel", self.sequence_parallel_size > 1) + # Attention config if self.attention_config is None: self.update_value("attention_config", [[["global"], self.num_layers]]) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 957960832..06066df7b 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -65,6 +65,11 @@ class NeoXArgsParallelism(NeoXArgsTemplate): Size of the model parallelism. """ + sequence_parallel_size: int = 1 + """ + Size of the model parallelism. + """ + pipe_partition_method: str = "type:transformer|mlp" """ method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number @@ -83,6 +88,12 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ + is_sequence_parallel: bool = False + """ + flag to determine whether sequence parallelism is on - shouldn't be set by user, is automatically determined + according to sequence parallel size. + """ + @dataclass class NeoXArgsModel(NeoXArgsTemplate): @@ -801,7 +812,7 @@ class NeoXArgsTraining(NeoXArgsTemplate): s3_chunk_size: int = 104_857_600 """ The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB. 
- """ + """ config_files: dict = None """ diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 5eb982384..6e147e2e5 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -28,7 +28,7 @@ from megatron import print_rank_0 from megatron import mpu -from megatron.utils import get_ltor_masks_and_position_ids, is_mp_rank_0 +from megatron.utils import get_ltor_masks_and_position_ids, is_tp_rank_0 def get_batch(neox_args, context_tokens: torch.Tensor): @@ -158,13 +158,15 @@ def forward_model(model, model_inputs, is_pipe_parallel=False) -> torch.Tensor: return logits -def broadcast_terminate_signal(terminate_runs: int): +def broadcast_terminate_signal(terminate_runs: int, neox_args): """Send signal to all workers to terminate if we've finished the process""" terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) + tp_sp_src_rank = mpu.get_sequence_parallel_src_rank() if neox_args.is_sequence_parallel else mpu.get_tensor_parallel_src_rank() + tp_sp_group = mpu.get_sequence_parallel_group() if neox_args.is_sequence_parallel else mpu.get_tensor_parallel_group() torch.distributed.broadcast( terminate_runs_tensor, - mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group(), + tp_sp_src_rank, + group=tp_sp_group, ) return terminate_runs_tensor[0].item() @@ -242,17 +244,19 @@ def stream_tokens( for i in range(0, len(stop_tokens)): stop_tokens[i] = torch.cuda.LongTensor(stop_tokens[i]) + tp_sp_src_rank = mpu.get_sequence_parallel_src_rank() if neox_args.is_sequence_parallel else mpu.get_tensor_parallel_src_rank() + tp_sp_group = mpu.get_sequence_parallel_group() if neox_args.is_sequence_parallel else mpu.get_tensor_parallel_group() # Make sure context tokens + start tokens are the same across all ranks token_generation_start_index = torch.cuda.LongTensor(context_lengths) torch.distributed.broadcast( context_tokens, - mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group(), + tp_sp_src_rank, + group=tp_sp_group, ) torch.distributed.broadcast( token_generation_start_index, - mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group(), + tp_sp_src_rank, + tp_sp_group, ) # get attention mask / position ids @@ -468,12 +472,12 @@ def generate_samples_from_prompt( "\nPlease give smaller context (e.g. 
half of the " "max sequence length)!", ) - if not is_mp_rank_0(): + if not is_tp_rank_0(): context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT") context_length = len(context_tokens) terminate_runs = 0 - terminate_runs = broadcast_terminate_signal(terminate_runs) + terminate_runs = broadcast_terminate_signal(terminate_runs, neox_args) if terminate_runs == 1: return generated_texts @@ -526,7 +530,7 @@ def generate_samples_from_prompt( generated_tokens = [] # this will happen if the first generated token is a stop token or eos token message = "WARNING: text generation did not start; try different batching or adjust parameters" - if is_mp_rank_0(): + if is_tp_rank_0(): data = { "context": raw_text, "text": generated_text, @@ -602,7 +606,7 @@ def generate_samples_input_from_file( "generate_samples_input_from_file() prompts loaded: {}".format(len(prompts)) ) - if is_mp_rank_0(): + if is_tp_rank_0(): if output_file is None: output_file = str(input_file) + ".output.jsonl" print_rank_0( @@ -624,7 +628,7 @@ def generate_samples_input_from_file( top_p=top_p, ) - if is_mp_rank_0(): + if is_tp_rank_0(): with open(output_file, "w") as f_out: for item in generated_texts: f_out.write(json.dumps(item) + "\n") @@ -689,7 +693,7 @@ def generate_samples_unconditional( top_p=top_p, ) - if is_mp_rank_0(): + if is_tp_rank_0(): if output_file is not None: with open(output_file, "w") as f_out: for item in generated_texts: @@ -737,7 +741,7 @@ def generate_samples_interactive( while True: model.module.clear_cache() # clear kv cache between batches - torch.distributed.barrier(group=mpu.get_model_parallel_group()) + torch.distributed.barrier(group=mpu.get_tensor_parallel_group()) terminate_runs = 0 if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: @@ -771,7 +775,7 @@ def generate_samples_interactive( context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT") context_length = len(context_tokens) - terminate_runs = broadcast_terminate_signal(terminate_runs) + terminate_runs = broadcast_terminate_signal(terminate_runs, neox_args) if terminate_runs == 1: return for ( @@ -791,7 +795,7 @@ def generate_samples_interactive( top_k=top_k, top_p=top_p, ): - if mpu.get_model_parallel_rank() == 0: + if (not neox_args.is_sequence_parallel and mpu.get_tensor_parallel_rank() == 0) or (neox_args.is_sequence_parallel and mpu.get_sequence_parallel_rank() == 0): generated_tokens = ( batch_context_tokens[0] .cpu() diff --git a/megatron/utils.py b/megatron/utils.py index 6a4545079..e19d3d610 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -131,9 +131,9 @@ def is_local_main(): return local_rank() == 0 -def is_mp_rank_0(): - """True if mp rank == 0""" - return mpu.get_model_parallel_rank() == 0 +def is_tp_rank_0(): + """True if tp rank == 0""" + return mpu.get_tensor_parallel_rank() == 0 def get_wandb_api_key(neox_args): @@ -391,8 +391,8 @@ def get_total_params(model): if mpu.get_data_parallel_rank() == 0: params = sum([p.nelement() for p in model.parameters()]) print( - " > number of parameters on model parallel rank {}: {}".format( - mpu.get_model_parallel_rank(), params + " > number of parameters on tensor parallel rank {}: {}".format( + mpu.get_tensor_parallel_rank(), params ), flush=True, ) diff --git a/tests/model/test_model_generation.py b/tests/model/test_model_generation.py index ab8bd756b..f95e6fed3 100644 --- a/tests/model/test_model_generation.py +++ b/tests/model/test_model_generation.py @@ -76,7 +76,7 @@ def wrapper(): def run_generate_test(param_dict, prompt): from 
megatron.text_generation_utils import generate_samples_from_prompt - from megatron.utils import is_mp_rank_0 + from megatron.utils import is_tp_rank_0 fixed_params = { "num_samples": 3, @@ -106,7 +106,7 @@ def run_generate_test(param_dict, prompt): ) # outputs only get generated on mp rank 0 - if is_mp_rank_0(): + if is_tp_rank_0(): assert len(output) == len(prompts) for prompt, out in zip(prompts, output): assert prompt == out["context"] From aa56a650d819e15893f1f30df604af7ec5961a57 Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 26 Nov 2023 05:09:47 +0000 Subject: [PATCH 2/2] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index bc2e8fc57..dea38bfa3 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 2da1083 + Default = c509f6a current git hash of repository @@ -862,6 +862,14 @@ Parallelism Arguments +- **sequence_parallel_size**: int + + Default = 1 + + Size of the model parallelism. + + + - **pipe_partition_method**: str Default = type:transformer|mlp @@ -889,6 +897,15 @@ Parallelism Arguments +- **is_sequence_parallel**: bool + + Default = False + + flag to determine whether sequence parallelism is on - shouldn't be set by user, is automatically determined + according to sequence parallel size. + + + ## NeoXArgsTemplate NeoXArgsTemplate()
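Taken together, the intended usage is: set "sequence_parallel_size" to the number of GPUs that should share one sequence, keep "model_parallel_size": 1 and "pipe_parallel_size": 1 (the new assertions enforce this), and select the "ulysses" attention path via attention_config so that ParallelSelfAttention wraps its local attention in DeepSpeed's DistributedAttention. The standalone sketch below illustrates only that wrapping, outside of GPT-NeoX; the tensor layout ([seq, batch, heads, head_dim]) and DeepSpeed's default scatter_idx=2 / gather_idx=0 are assumptions to verify against the installed DeepSpeed version, and local_attention is a stand-in for ParallelSelfAttention.attention. Launch with e.g. `torchrun --nproc_per_node=2 ulysses_sketch.py` (hypothetical file name) so that the whole world forms one sequence-parallel group.

    import torch
    import torch.distributed as dist
    from deepspeed.sequence.layer import DistributedAttention


    def local_attention(query, key, value, attention_mask=None):
        # Plain scaled-dot-product attention over whatever this rank holds locally.
        # Inside DistributedAttention, the first all-to-all gives every rank the
        # full sequence but only heads / sp_size of the heads, so this runs on
        # [full_seq, batch, local_heads, head_dim] tensors.
        q = query.permute(1, 2, 0, 3)  # -> [batch, heads, seq, head_dim]
        k = key.permute(1, 2, 0, 3)
        v = value.permute(1, 2, 0, 3)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.permute(2, 0, 1, 3)  # back to [seq, batch, heads, head_dim]


    def main():
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        # One sequence-parallel group spanning all ranks, standing in for
        # mpu's sequence parallel group when sp_size == world size.
        sp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

        # Same wrapping as in ParallelSelfAttention.__init__ above:
        # DistributedAttention(local attention callable, sequence-parallel group).
        dist_attn = DistributedAttention(local_attention, sp_group)

        seq_per_rank, batch, heads, head_dim = 1024, 2, 8, 64
        q, k, v = (
            torch.randn(seq_per_rank, batch, heads, head_dim, device="cuda", dtype=torch.float16)
            for _ in range(3)
        )

        # Each rank feeds its own 1/sp_size slice of the sequence and gets the
        # matching context slice back, with all heads restored.
        context = dist_attn(q, k, v)
        print(f"rank {rank}: context {tuple(context.shape)}")


    if __name__ == "__main__":
        main()

For the 125M config shown at the top of the patch, the analogous in-repo setup would be the new "sequence_parallel_size" key set to the GPU count per sequence group, with the existing "pipe_parallel_size" and "model_parallel_size" left at 1 and the per-layer attention type set to "ulysses".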