diff --git a/megatron/arguments.py b/megatron/arguments.py
index 9813d2b964..a4a452bffc 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -64,10 +64,6 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
-    if args.pipeline_model_parallel_size > 1:
-        if "ring_exchange" not in dir(torch.distributed):
-            raise Exception('PyTorch with torch.distributed.ring_exchange '
-                            'needed to run pipeline MP!')
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py
index 5b4cc2a30b..15492da1ea 100644
--- a/megatron/mpu/__init__.py
+++ b/megatron/mpu/__init__.py
@@ -36,6 +36,8 @@
 from .initialize import get_tensor_model_parallel_src_rank
 from .initialize import get_pipeline_model_parallel_first_rank
 from .initialize import get_pipeline_model_parallel_last_rank
+from .initialize import get_pipeline_model_parallel_next_rank
+from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py
index dcd8b41f0a..9fb829bfd2 100644
--- a/megatron/mpu/initialize.py
+++ b/megatron/mpu/initialize.py
@@ -276,16 +276,30 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size
 
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
+
 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
-def get_pipeline_model_parallel_first_rank():
+def get_pipeline_model_parallel_next_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
-    return _PIPELINE_GLOBAL_RANKS[0]
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+
+def get_pipeline_model_parallel_prev_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
 
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
diff --git a/megatron/training.py b/megatron/training.py
index a3d07835a4..64384a745f 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -325,7 +325,7 @@ def setup_model_and_optimizer(model_provider_func):
 
 
 def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
-    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
+    """Communicate tensors between stages."""
     args = get_args()
 
     # Create placeholder tensors for receive in forward and backward directions
@@ -348,11 +348,26 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
                                        dtype=dtype)
 
     # Send tensors in both the forward and backward directions as appropriate.
-    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                    tensor_recv_prev=tensor_recv_prev,
-                                    tensor_send_next=tensor_send_next,
-                                    tensor_recv_next=tensor_recv_next,
-                                    group=mpu.get_pipeline_model_parallel_group())
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(recv_next_op)
+    reqs = torch.distributed.batch_isend_irecv(ops)
+    for req in reqs:
+        req.wait()
 
     return tensor_recv_prev, tensor_recv_next
 
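For reviewers who want to exercise the new communication path in isolation, below is a minimal, standalone sketch (not part of this patch) of the same `P2POp` + `batch_isend_irecv` pattern that `communicate()` now uses, with the `(rank ± 1) % world_size` neighbor arithmetic from `get_pipeline_model_parallel_next_rank()` / `get_pipeline_model_parallel_prev_rank()` inlined. The script name, tensor shapes, and the CPU/gloo backend are illustrative assumptions, not Megatron code; Megatron issues the equivalent ops against ranks of the pipeline-model-parallel group.

```python
# ring_p2p_demo.py -- illustrative sketch only, not part of the patch.
# Run with at least 3 processes so prev and next are distinct peers, e.g.:
#   torchrun --nproc_per_node=3 ring_p2p_demo.py
import torch
import torch.distributed as dist


def exchange_with_neighbors(tensor_send_next, tensor_send_prev):
    """Send one tensor to each ring neighbor and receive one from each,
    batching the four point-to-point ops the way communicate() does."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    next_rank = (rank + 1) % world_size   # cf. get_pipeline_model_parallel_next_rank()
    prev_rank = (rank - 1) % world_size   # cf. get_pipeline_model_parallel_prev_rank()

    tensor_recv_prev = torch.empty_like(tensor_send_next)  # what prev sends "forward"
    tensor_recv_next = torch.empty_like(tensor_send_prev)  # what next sends "backward"

    # Queue the sends/recvs as P2POp(op, tensor, peer) descriptors, then launch
    # them as one batch and wait on every returned request.
    ops = [
        dist.P2POp(dist.isend, tensor_send_prev, prev_rank),
        dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank),
        dist.P2POp(dist.isend, tensor_send_next, next_rank),
        dist.P2POp(dist.irecv, tensor_recv_next, next_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    return tensor_recv_prev, tensor_recv_next


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # CPU-only backend keeps the demo simple
    rank = dist.get_rank()
    send_next = torch.full((4,), float(rank))          # stand-in for forward activations
    send_prev = torch.full((4,), float(rank) + 100.0)  # stand-in for backward grads
    recv_prev, recv_next = exchange_with_neighbors(send_next, send_prev)
    print(f"rank {rank}: from prev {recv_prev[0].item()}, from next {recv_next[0].item()}")
    dist.destroy_process_group()
```

Unlike this sketch, the patched `communicate()` builds the op list conditionally, skipping any tensor that is `None`, so a stage that only needs to talk to one neighbor (for example the first or last pipeline stage) never posts transfers in the unused direction.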