Move layers from mpu to core.tensor_parallel.
jaredcasper committed Sep 23, 2022
1 parent 209f91c commit c2ea914
Showing 22 changed files with 508 additions and 326 deletions.
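The practical upshot for downstream code is an import-path change: helpers that used to be reached through megatron.mpu now live under megatron.core (parallel-state queries) and megatron.core.tensor_parallel (layers, mappings, RNG tracking). A before/after sketch, based only on the call sites updated in the diffs below:

    # Before this commit: parallel-state and RNG helpers via megatron.mpu.
    from megatron import mpu

    dp_rank = mpu.get_data_parallel_rank()
    rng_states = mpu.get_cuda_rng_tracker().get_states()

    # After this commit: parallel state is exposed from megatron.core, while the
    # layer/RNG utilities move to megatron.core.tensor_parallel.
    from megatron import core

    dp_rank = core.get_data_parallel_rank()
    rng_states = core.tensor_parallel.get_cuda_rng_tracker().get_states()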
1 change: 0 additions & 1 deletion megatron/__init__.py
@@ -10,7 +10,6 @@
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
-from .global_vars import get_global_memory_buffer
 from .initialize import initialize_megatron
 
 from .utils import (print_rank_0,
36 changes: 18 additions & 18 deletions megatron/checkpointing.py
@@ -9,7 +9,7 @@
 
 import torch
 
-from megatron import (mpu,
+from megatron import (core,
                       update_num_microbatches)
 from .global_vars import get_args
 from .utils import (unwrap_model,
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
 
     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
-        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
+        pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
-        tensor_rank = mpu.get_tensor_model_parallel_rank()
+        tensor_rank = core.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
-        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+        pipeline_rank = core.get_pipeline_model_parallel_rank()
 
     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
-            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            common_path + "_%03d" % core.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
         'np_rng_state': np.random.get_state(),
         'torch_rng_state': torch.get_rng_state(),
         'cuda_rng_state': torch.cuda.get_rng_state(),
-        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+        'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
 
     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            mpu.get_data_parallel_world_size() > 1 and \
+            core.get_data_parallel_world_size() > 1 and \
             args.data_parallel_random_init:
         rng_state_list = \
-            [None for i in range(mpu.get_data_parallel_world_size())]
+            [None for i in range(core.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
             rng_state_list,
             rng_state,
-            group=mpu.get_data_parallel_group())
+            group=core.get_data_parallel_group())
     else:
         rng_state_list = [rng_state]
@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Collect args, model, RNG.
     model_state_dict = {}
     if not torch.distributed.is_initialized() \
-       or mpu.get_data_parallel_rank() == 0:
+       or core.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
         model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
             model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                core.set_virtual_pipeline_model_parallel_rank(i)
                 model_state_dict['model%d' % i] = \
                     model[i].state_dict_for_save_checkpoint()
 
@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     optim_state_dict = {}
     if not args.no_save_optim \
        and (not torch.distributed.is_initialized()
-            or mpu.get_data_parallel_rank() == 0
+            or core.get_data_parallel_rank() == 0
             or args.use_distributed_optimizer):
 
         # Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
-            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            core.set_virtual_pipeline_model_parallel_rank(i)
             model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
 
     # Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             # access rng_state for data parallel rank
             if args.data_parallel_random_init:
 
-                rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
             else:
                 rng_state = model_state_dict['rng_state'][0]
             random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
-                mpu.get_cuda_rng_tracker().set_states(
+                core.tensor_parallel.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
         else: # backward compatability
             random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             # Check for empty states array
             if not model_state_dict['rng_tracker_states']:
                 raise KeyError
-            mpu.get_cuda_rng_tracker().set_states(
+            core.tensor_parallel.get_cuda_rng_tracker().set_states(
                 model_state_dict['rng_tracker_states'])
     except KeyError:
         print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
                                               args.use_distributed_optimizer,
                                               release=False)
 
-    if mpu.get_data_parallel_rank() == 0:
+    if core.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
 
@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
         model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()
 
-    if mpu.get_data_parallel_rank() == 0:
+    if core.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))
 
     return model
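For reference, the rng_tracker_states entry that checkpointing.py saves and restores above is a round trip through the tracker's get_states()/set_states() pair. A minimal sketch, assuming the tracker has already been seeded during model setup (e.g. via model_parallel_cuda_manual_seed):

    from megatron import core

    # Save side, as in get_rng_state(): snapshot the per-region CUDA RNG states.
    tracker = core.tensor_parallel.get_cuda_rng_tracker()
    saved_states = tracker.get_states()

    # ... checkpoint is written to disk and later reloaded ...

    # Load side, as in load_checkpoint(): restore the same states so dropout
    # patterns inside tensor-parallel regions are reproducible after a resume.
    core.tensor_parallel.get_cuda_rng_tracker().set_states(saved_states)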
6 changes: 6 additions & 0 deletions megatron/core/__init__.py
@@ -1,7 +1,13 @@
from .parallel_state import (
    initialize_model_parallel,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
    get_pipeline_model_parallel_world_size,
    get_pipeline_model_parallel_rank,
    get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank,
    get_data_parallel_world_size,
    get_data_parallel_rank,
    get_global_memory_buffer,
    get_num_layers,
)
from megatron.core import tensor_parallel
23 changes: 22 additions & 1 deletion megatron/core/parallel_state.py
@@ -5,6 +5,8 @@
 import torch
 from typing import Optional
 
+from .utils import GlobalMemoryBuffer
+
 # Intra-layer model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Inter-layer model parallel group that the current rank belongs to.
@@ -42,7 +44,8 @@
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None
 
-
+# Memory buffers to avoid dynamic memory allocation
+_GLOBAL_MEMORY_BUFFER = None
 
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
@@ -195,6 +198,12 @@ def initialize_model_parallel(
     if rank in ranks:
         _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
 
+    # Initialize global memory buffer
+    # This isn't really "parallel state" but there isn't another good place to
+    # put this. If we end up with a more generic initialization of megatron-core
+    # we could stick it there
+    _set_global_memory_buffer()
+
 
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
@@ -506,6 +515,18 @@ def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
     return torch.distributed.get_rank(group=get_data_parallel_group())
 
+def _set_global_memory_buffer():
+    """Initialize global buffer"""
+    global _GLOBAL_MEMORY_BUFFER
+    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
+    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
+
+def get_global_memory_buffer():
+    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
+    return _GLOBAL_MEMORY_BUFFER
+
+
+
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
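The buffer wired up here exists to hand out reusable workspace tensors instead of allocating fresh ones every step. A hedged usage sketch; the get_tensor(shape, dtype, name) call and the buffer name are assumptions about the GlobalMemoryBuffer API in megatron/core/utils.py, which this commit imports but whose diff is not shown here:

    import torch
    from megatron import core

    # Available once initialize_model_parallel() has called _set_global_memory_buffer().
    buf = core.get_global_memory_buffer()

    # Assumed API: request a named workspace of a given shape/dtype; repeated
    # requests under the same name ("attn_scores" is a hypothetical label)
    # reuse the same underlying allocation instead of reallocating.
    workspace = buf.get_tensor((4096, 1024), torch.float16, "attn_scores")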
48 changes: 48 additions & 0 deletions megatron/core/tensor_parallel/__init__.py
@@ -1,9 +1,57 @@
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data

from .layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
    set_defaults_if_not_set_tensor_model_parallel_attributes,
    copy_tensor_model_parallel_attributes,
    param_is_not_tensor_parallel_duplicate,
    linear_with_grad_accumulation_and_async_allreduce
)

from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    scatter_to_tensor_model_parallel_region,
    scatter_to_sequence_parallel_region,
)

from .random import (
    checkpoint,
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed
)

from .utils import split_tensor_along_last_dim

__all__ = [
    # cross_entropy.py
    "vocab_parallel_cross_entropy",
    # data.py
    "broadcast_data",
    # layers.py
    "ColumnParallelLinear",
    "RowParallelLinear",
    "VocabParallelEmbedding",
    "set_defaults_if_not_set_tensor_model_parallel_attributes",
    "copy_tensor_model_parallel_attributes",
    "param_is_not_tensor_parallel_duplicate",
    "linear_with_grad_accumulation_and_async_allreduce",
    # mappings.py
    "copy_to_tensor_model_parallel_region",
    "gather_from_tensor_model_parallel_region",
    "gather_from_sequence_parallel_region",
    # "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py
    "checkpoint",
    "get_cuda_rng_tracker",
    "model_parallel_cuda_manual_seed",
    # utils.py
    "split_tensor_along_last_dim",
]
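As a small usage example of the new namespace, the re-exported split_tensor_along_last_dim helper can be called directly. A sketch, assuming it keeps the signature it had under mpu, roughly (tensor, num_partitions, contiguous_split_chunks=False):

    import torch
    from megatron.core import tensor_parallel

    # Partition the last dimension of an activation across 4 tensor-parallel ranks.
    x = torch.randn(2, 8, 1024)
    chunks = tensor_parallel.split_tensor_along_last_dim(x, 4)
    print([tuple(c.shape) for c in chunks])  # four views of shape (2, 8, 256)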
(Diffs for the remaining 17 changed files are not shown.)
