Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

jaredcasper committed Sep 24, 2022
1 parent c2ea914 commit 5942af9
Showing 63 changed files with 273 additions and 319 deletions.
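Most of the diff is a mechanical change at call sites: code that previously reached the parallel-state helpers through the core package now uses the mpu alias, and RNG/data-movement helpers are called through tensor_parallel. A minimal before/after sketch of the pattern (the surrounding variable names are illustrative, not taken from any single file):

# Before: call sites imported megatron.core as "core" and went through it
# for both parallel-state queries and the CUDA RNG tracker.
from megatron import core
rank = core.get_data_parallel_rank()
states = core.tensor_parallel.get_cuda_rng_tracker().get_states()

# After: parallel_state is aliased as mpu inside megatron.core, and
# tensor_parallel is imported alongside it.
from megatron.core import mpu, tensor_parallel
rank = mpu.get_data_parallel_rank()
states = tensor_parallel.get_cuda_rng_tracker().get_states()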
38 changes: 19 additions & 19 deletions megatron/checkpointing.py
@@ -9,8 +9,8 @@

import torch

from megatron import (core,
update_num_microbatches)
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,

# Use both the tensor and pipeline MP rank.
if pipeline_parallel is None:
pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = core.get_tensor_model_parallel_rank()
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = core.get_pipeline_model_parallel_rank()
pipeline_rank = mpu.get_pipeline_model_parallel_rank()

# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt")
optim_name = os.path.join(
common_path + "_%03d" % core.get_data_parallel_rank(),
common_path + "_%03d" % mpu.get_data_parallel_rank(),
"optim.pt")
else:
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

rng_state_list = None
if torch.distributed.is_initialized() and \
core.get_data_parallel_world_size() > 1 and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(core.get_data_parallel_world_size())]
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
rng_state,
group=core.get_data_parallel_group())
group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]

@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
# Collect args, model, RNG.
model_state_dict = {}
if not torch.distributed.is_initialized() \
or core.get_data_parallel_rank() == 0:
or mpu.get_data_parallel_rank() == 0:

# Arguments, iteration, and model.
model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
core.set_virtual_pipeline_model_parallel_rank(i)
mpu.set_virtual_pipeline_model_parallel_rank(i)
model_state_dict['model%d' % i] = \
model[i].state_dict_for_save_checkpoint()

@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
optim_state_dict = {}
if not args.no_save_optim \
and (not torch.distributed.is_initialized()
or core.get_data_parallel_rank() == 0
or mpu.get_data_parallel_rank() == 0
or args.use_distributed_optimizer):

# Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
model[0].load_state_dict(model_state_dict['model'], strict=strict)
else:
for i in range(len(model)):
core.set_virtual_pipeline_model_parallel_rank(i)
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

# Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# access rng_state for data parallel rank
if args.data_parallel_random_init:

rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = model_state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
core.tensor_parallel.get_cuda_rng_tracker().set_states(
tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatability
random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not model_state_dict['rng_tracker_states']:
raise KeyError
core.tensor_parallel.get_cuda_rng_tracker().set_states(
tensor_parallel.get_cuda_rng_tracker().set_states(
model_state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
args.use_distributed_optimizer,
release=False)

if core.get_data_parallel_rank() == 0:
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))

@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier()

if core.get_data_parallel_rank() == 0:
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))

return model
25 changes: 12 additions & 13 deletions megatron/core/__init__.py
@@ -1,13 +1,12 @@
from .parallel_state import (
initialize_model_parallel,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_pipeline_model_parallel_rank,
get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank,
get_data_parallel_world_size,
get_data_parallel_rank,
get_global_memory_buffer,
get_num_layers,
)
from megatron.core import tensor_parallel
import megatron.core.parallel_state
import megatron.core.tensor_parallel
import megatron.core.utils

# Alias parallel_state as mpu, its legacy name
mpu = parallel_state

__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
]
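The alias is valid Python because importing a submodule binds it as an attribute of the parent package: once import megatron.core.parallel_state has run inside megatron/core/__init__.py, the name parallel_state exists in that module's namespace, so mpu = parallel_state simply re-exports the same module object under its legacy name. A small sketch of what callers can rely on (assuming only that megatron.core is importable):

from megatron.core import mpu, parallel_state

# Both names refer to the same module object, so legacy mpu.* calls and
# new parallel_state.* calls are interchangeable.
assert mpu is parallel_state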
4 changes: 0 additions & 4 deletions megatron/core/parallel_state.py
@@ -47,10 +47,6 @@
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
9 changes: 7 additions & 2 deletions megatron/core/tensor_parallel/__init__.py
@@ -5,6 +5,7 @@
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
@@ -23,10 +24,14 @@
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed
model_parallel_cuda_manual_seed,
)

from .utils import split_tensor_along_last_dim
from .utils import (
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)

__all__ = [
# cross_entropy.py
36 changes: 4 additions & 32 deletions megatron/core/tensor_parallel/random.py
@@ -17,6 +17,10 @@
get_tensor_model_parallel_world_size,
)

from .utils import (
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -55,38 +59,6 @@ def cb():
_lazy_call(cb)


def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data


def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
numel_gathered = torch.numel(tensor) * \
get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered


class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
75 changes: 66 additions & 9 deletions megatron/core/tensor_parallel/utils.py
@@ -10,12 +10,16 @@ def split_tensor_along_last_dim(
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
@@ -28,11 +32,64 @@

return tensor_list

def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data


def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=get_tensor_model_parallel_group())
return gathered


class VocabUtility:
"""Split the vocabulary into `world_size` chunks and return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)
"""

@staticmethod
def vocab_range_from_per_partition_vocab_size(
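split_tensor_into_1d_equal_chunks and gather_split_1d_tensor now live here and are re-exported from megatron.core.tensor_parallel (see the __init__.py change above). A hedged round-trip sketch, assuming torch.distributed and the tensor model parallel group have already been initialized and that the tensor's element count divides evenly by the tensor-parallel world size:

import torch
from megatron.core.tensor_parallel import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

full = torch.arange(16, dtype=torch.float32, device=torch.cuda.current_device())
# Each tensor-parallel rank keeps a contiguous numel // world_size slice;
# new_buffer=True copies it into a fresh buffer instead of returning a view.
local = split_tensor_into_1d_equal_chunks(full, new_buffer=True)
# The all-gather concatenates the per-rank slices in rank order, so every
# rank in the group recovers the flattened original tensor.
restored = gather_split_1d_tensor(local)
assert torch.equal(restored, full.view(-1))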
29 changes: 0 additions & 29 deletions megatron/core/utils.py
@@ -21,35 +21,6 @@ def divide(numerator, denominator):
return numerator // denominator


def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = (
torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
)
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]


def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = parallel_state.get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(
gathered,
tensor,
group=parallel_state.get_tensor_model_parallel_group()
)
return gathered

class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
5 changes: 3 additions & 2 deletions megatron/data/biencoder_dataset_utils.py
@@ -4,7 +4,8 @@
import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron import get_args, get_tokenizer, print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

# Unpack.
query_tokens = data_b['query_tokens'].long()
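The same split shows up at data-loading call sites: parallel-state queries stay on mpu, while broadcast_data is now addressed through tensor_parallel. A minimal sketch of the updated pattern, assuming broadcast_data keeps its existing (keys, data, datatype) interface and the model parallel groups are initialized; the key names beyond query_tokens are illustrative:

import torch
from megatron.core import tensor_parallel

def get_batch(data_iterator):
    # Illustrative key list; the real dataset defines the full set of keys.
    keys = ['query_tokens', 'query_mask']
    datatype = torch.int64
    data = None if data_iterator is None else next(data_iterator)
    # Was: data_b = mpu.broadcast_data(keys, data, datatype)
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
    return data_b['query_tokens'].long()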
(Diffs for the remaining changed files are not shown.)