Merge branch 'send_and_recv' into 'main'
Use batched send and recv instead of torch.distributed.ring_exchange()

See merge request ADLR/megatron-lm!198
jaredcasper committed Jan 5, 2021
2 parents 2348c99 + d899988 commit 6fa3684
Showing 4 changed files with 39 additions and 12 deletions.
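
For context, torch.distributed.ring_exchange() was a custom extension that required a patched PyTorch build; this commit replaces it with the stock batched point-to-point API (torch.distributed.P2POp plus torch.distributed.batch_isend_irecv). Below is a minimal, self-contained sketch of that pattern. It is not part of the commit; the backend choice, tensor shapes, and launch command are ours.

import torch
import torch.distributed as dist

def ring_exchange_sketch():
    # Each rank sends a tensor to its next neighbour and receives one from its
    # previous neighbour, with both ops issued as a single batch.
    dist.init_process_group(backend="gloo")  # NCCL would be the choice on GPUs
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    send_buf = torch.full((4,), float(rank))
    recv_buf = torch.empty(4)

    ops = [dist.P2POp(dist.isend, send_buf, next_rank),
           dist.P2POp(dist.irecv, recv_buf, prev_rank)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    print(f"rank {rank} got {recv_buf.tolist()} from rank {prev_rank}")

if __name__ == "__main__":
    # Launch with a distributed launcher on two or more processes, e.g.:
    #   torchrun --nproc_per_node=2 this_script.py
    ring_exchange_sketch()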
megatron/arguments.py (4 changes: 0 additions & 4 deletions)
@@ -64,10 +64,6 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
-    if args.pipeline_model_parallel_size > 1:
-        if "ring_exchange" not in dir(torch.distributed):
-            raise Exception('PyTorch with torch.distributed.ring_exchange '
-                            'needed to run pipeline MP!')
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
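
The deleted guard existed because ring_exchange() is not part of stock PyTorch, so pipeline parallelism previously needed a specially patched build. The batched primitives used in training.py below ship with recent stock PyTorch releases, so the check is simply dropped rather than replaced. Purely as an illustration (not in the commit), a defensive check against the new API could read:

# Hypothetical replacement guard, shown only for illustration.
if args.pipeline_model_parallel_size > 1:
    assert hasattr(torch.distributed, "batch_isend_irecv"), \
        "pipeline model parallelism requires a PyTorch version that " \
        "provides torch.distributed.batch_isend_irecv"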
megatron/mpu/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -36,6 +36,8 @@
 from .initialize import get_tensor_model_parallel_src_rank
 from .initialize import get_pipeline_model_parallel_first_rank
 from .initialize import get_pipeline_model_parallel_last_rank
+from .initialize import get_pipeline_model_parallel_next_rank
+from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
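
These re-exports let the rest of the codebase look up its pipeline neighbours through the mpu package rather than reaching into mpu.initialize. A brief usage illustration (the variable names are ours):

from megatron import mpu

next_stage_rank = mpu.get_pipeline_model_parallel_next_rank()  # global rank of the following stage
prev_stage_rank = mpu.get_pipeline_model_parallel_prev_rank()  # global rank of the preceding stage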
megatron/mpu/initialize.py (18 changes: 16 additions & 2 deletions)
@@ -276,16 +276,30 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size
 
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
+
 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
-def get_pipeline_model_parallel_first_rank():
+def get_pipeline_model_parallel_next_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
-    return _PIPELINE_GLOBAL_RANKS[0]
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+
+def get_pipeline_model_parallel_prev_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
 
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
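
The new helpers index _PIPELINE_GLOBAL_RANKS modulo the pipeline world size so that both ends of the pipeline wrap around. A small worked example under assumed values (the group layout here is hypothetical):

# Suppose a pipeline group of size 4 built from global ranks [0, 8, 16, 24],
# with this process at pipeline rank 2 (global rank 16).
_PIPELINE_GLOBAL_RANKS = [0, 8, 16, 24]
rank_in_pipeline, world_size = 2, 4

next_rank = _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]  # -> 24
prev_rank = _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]  # -> 8

# At the last stage (pipeline rank 3) the "next" rank wraps to
# _PIPELINE_GLOBAL_RANKS[0]; at the first stage, Python's % keeps the index
# non-negative, so the "previous" rank wraps to the last entry.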
megatron/training.py (27 changes: 21 additions & 6 deletions)
@@ -325,7 +325,7 @@ def setup_model_and_optimizer(model_provider_func):
 
 
 def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
-    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
+    """Communicate tensors between stages."""
     args = get_args()
 
     # Create placeholder tensors for receive in forward and backward directions
@@ -348,11 +348,26 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
                                        dtype=dtype)
 
     # Send tensors in both the forward and backward directions as appropriate.
-    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                    tensor_recv_prev=tensor_recv_prev,
-                                    tensor_send_next=tensor_send_next,
-                                    tensor_recv_next=tensor_recv_next,
-                                    group=mpu.get_pipeline_model_parallel_group())
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(recv_next_op)
+    reqs = torch.distributed.batch_isend_irecv(ops)
+    for req in reqs:
+        req.wait()
 
     return tensor_recv_prev, tensor_recv_next
 
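
For orientation, a hedged sketch of how an interior pipeline stage might drive communicate() for one forward exchange. The real call sites live elsewhere in training.py and are not part of this hunk, and forward_step_on_this_stage is a made-up placeholder:

# Illustration only: receive the activation from the previous stage, run this
# stage's layers, then send the result on to the next stage.
input_tensor, _ = communicate(tensor_send_next=None,
                              tensor_send_prev=None,
                              recv_forward=True,
                              recv_backward=False)
output_tensor = forward_step_on_this_stage(input_tensor)  # hypothetical helper
communicate(tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)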
