diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 3ab0ad0b74..a839380710 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -9,8 +9,8 @@
 
 import torch
 
-from megatron import (core,
-                      update_num_microbatches)
+from megatron import update_num_microbatches
+from megatron.core import mpu, tensor_parallel
 from .global_vars import get_args
 from .utils import (unwrap_model,
                     print_rank_0)
@@ -79,11 +79,11 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     # Use both the tensor and pipeline MP rank.
     if pipeline_parallel is None:
-        pipeline_parallel = (core.get_pipeline_model_parallel_world_size() > 1)
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
-        tensor_rank = core.get_tensor_model_parallel_rank()
+        tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
-        pipeline_rank = core.get_pipeline_model_parallel_rank()
+        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
 
     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
@@ -98,7 +98,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
-            common_path + "_%03d" % core.get_data_parallel_rank(),
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -185,18 +185,18 @@ def get_rng_state():
         'np_rng_state': np.random.get_state(),
         'torch_rng_state': torch.get_rng_state(),
         'cuda_rng_state': torch.cuda.get_rng_state(),
-        'rng_tracker_states': core.tensor_parallel.get_cuda_rng_tracker().get_states()}
+        'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
 
     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            core.get_data_parallel_world_size() > 1 and \
+            mpu.get_data_parallel_world_size() > 1 and \
             args.data_parallel_random_init:
         rng_state_list = \
-            [None for i in range(core.get_data_parallel_world_size())]
+            [None for i in range(mpu.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
-           group=core.get_data_parallel_group())
+           group=mpu.get_data_parallel_group())
     else:
         rng_state_list = [rng_state]
@@ -223,7 +223,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Collect args, model, RNG.
     model_state_dict = {}
     if not torch.distributed.is_initialized() \
-       or core.get_data_parallel_rank() == 0:
+       or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
         model_state_dict['args'] = args
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
             model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
-                core.set_virtual_pipeline_model_parallel_rank(i)
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
                 model_state_dict['model%d' % i] = \
                     model[i].state_dict_for_save_checkpoint()
@@ -246,7 +246,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     optim_state_dict = {}
     if not args.no_save_optim \
        and (not torch.distributed.is_initialized()
-            or core.get_data_parallel_rank() == 0
+            or mpu.get_data_parallel_rank() == 0
             or args.use_distributed_optimizer):
 
         # Optimizer stuff.
@@ -548,7 +548,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
-            core.set_virtual_pipeline_model_parallel_rank(i)
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
             model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
 
     # Fix up query/key/value matrix ordering if needed
@@ -580,7 +580,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
 
             # access rng_state for data parallel rank
             if args.data_parallel_random_init:
-                rng_state = model_state_dict['rng_state'][core.get_data_parallel_rank()]
+                rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
             else:
                 rng_state = model_state_dict['rng_state'][0]
             random.setstate(rng_state['random_rng_state'])
@@ -590,7 +590,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
-                core.tensor_parallel.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
         else:  # backward compatability
             random.setstate(model_state_dict['random_rng_state'])
@@ -600,7 +600,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not model_state_dict['rng_tracker_states']:
                     raise KeyError
-                core.tensor_parallel.get_cuda_rng_tracker().set_states(
+                tensor_parallel.get_cuda_rng_tracker().set_states(
                     model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
@@ -640,7 +640,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
         args.use_distributed_optimizer,
         release=False)
 
-    if core.get_data_parallel_rank() == 0:
+    if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
@@ -656,7 +656,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()
 
-    if core.get_data_parallel_rank() == 0:
+    if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))
 
     return model
diff --git a/megatron/core/__init__.py b/megatron/core/__init__.py
index 113b24b435..cb437d5dae 100644
--- a/megatron/core/__init__.py
+++ b/megatron/core/__init__.py
@@ -1,13 +1,12 @@
-from .parallel_state import (
-    initialize_model_parallel,
-    get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    get_pipeline_model_parallel_world_size,
-    get_pipeline_model_parallel_rank,
-    get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank,
-    get_data_parallel_world_size,
-    get_data_parallel_rank,
-    get_global_memory_buffer,
-    get_num_layers,
-)
-from megatron.core import tensor_parallel
+import megatron.core.parallel_state
+import megatron.core.tensor_parallel
+import megatron.core.utils
+
+# Alias parallel_state as mpu, its legacy name
+mpu = parallel_state
+
+__all__ = [
+    "parallel_state",
+    "tensor_parallel",
+    "utils",
+]
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 8119745a04..2caeb4bea1 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -47,10 +47,6 @@
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None
 
-def is_unitialized():
-    """Useful for code segments that may be accessed with or without mpu initialization"""
-    return _DATA_PARALLEL_GROUP is None
-
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py
index ae35aa07d0..4da239e5f5 100644
--- a/megatron/core/tensor_parallel/__init__.py
+++ b/megatron/core/tensor_parallel/__init__.py
@@ -5,6 +5,7 @@
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
+    set_tensor_model_parallel_attributes,
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
@@ -23,10 +24,14 @@
 from .random import (
     checkpoint,
     get_cuda_rng_tracker,
-    model_parallel_cuda_manual_seed
+    model_parallel_cuda_manual_seed,
 )
 
-from .utils import split_tensor_along_last_dim
+from .utils import (
+    split_tensor_along_last_dim,
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)
 
 __all__ = [
     # cross_entropy.py
diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py
index e0b8ae4347..228f208c8d 100644
--- a/megatron/core/tensor_parallel/random.py
+++ b/megatron/core/tensor_parallel/random.py
@@ -17,6 +17,10 @@
     get_tensor_model_parallel_world_size,
 )
 
+from .utils import (
+    split_tensor_into_1d_equal_chunks,
+    gather_split_1d_tensor,
+)
 
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -55,38 +59,6 @@ def cb():
 
     _lazy_call(cb)
 
 
-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
-    """Break a tensor into equal 1D chunks."""
-    partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-        data.copy_(tensor.view(-1)[start_index:end_index])
-    else:
-        data = tensor.view(-1)[start_index:end_index]
-    return data
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-    # TODO: This API is experimental in pytorch (as of Feb 2022) and
-    # this might break in future pytorch releases. We chose this API
-    # as opposed to torch.distributed.all_gather for efficiency reasons.
-    # This API calls directly NCCL all-gather versus the former does
-    # internal copies and can potentially cause slow down.
-    torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
-    return gathered
-
-
 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
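Note on the new package surface: after the `megatron/core/__init__.py` change above, callers reach parallel state through `megatron.core`, with `mpu` kept as an alias for `parallel_state`. A minimal sketch of the resulting imports (assumes torch.distributed and model parallelism are already initialized):

    from megatron.core import mpu, tensor_parallel

    # mpu is megatron.core.parallel_state under its legacy name.
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()

    # The CUDA RNG tracker is now reached through tensor_parallel.
    rng_states = tensor_parallel.get_cuda_rng_tracker().get_states()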
diff --git a/megatron/core/tensor_parallel/utils.py b/megatron/core/tensor_parallel/utils.py
index e5ad7a9360..5399a90b1d 100644
--- a/megatron/core/tensor_parallel/utils.py
+++ b/megatron/core/tensor_parallel/utils.py
@@ -10,12 +10,16 @@ def split_tensor_along_last_dim(
     num_partitions: int,
     contiguous_split_chunks: bool = False,
 ) -> List[torch.Tensor]:
-    """Split a tensor along its last dimension.
-    Arguments:
-        tensor: input tensor.
-        num_partitions: number of partitions to split the tensor
-        contiguous_split_chunks: If True, make each chunk contiguous
-                                 in memory.
+    """ Split a tensor along its last dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+
+    Returns:
+        A list of Tensors
     """
     # Get the size and dimension.
     last_dim = tensor.dim() - 1
@@ -28,11 +32,64 @@ def split_tensor_along_last_dim(
     return tensor_list
 
 
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
+
+    Returns a Tensor or View with this rank's portion of the data.
+
+    Arguments:
+        tensor: The tensor to split
+
+    Keyword Arguments:
+        new_buffer (bool): If True, returns a new Tensor.
+                           If False, returns a view into the existing Tensor.
+                           Default is False
+
+    """
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
+    start_index = partition_size * get_tensor_model_parallel_rank()
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
+    model parallel ranks.
+
+    Returns a new Tensor with the gathered data.
+
+    Arguments:
+        tensor: A Tensor or view of this rank's portion of the data.
+    """
+    numel_gathered = torch.numel(tensor) * \
+        get_tensor_model_parallel_world_size()
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    # TODO: This API is experimental in pytorch (as of Feb 2022) and
+    # this might break in future pytorch releases. We chose this API
+    # as opposed to torch.distributed.all_gather for efficiency reasons.
+    # This API calls directly NCCL all-gather versus the former does
+    # internal copies and can potentially cause slow down.
+    torch.distributed._all_gather_base(gathered, tensor,
+                                       group=get_tensor_model_parallel_group())
+    return gathered
+
+
 class VocabUtility:
-    """Split the vocabulary into `world_size` chunks and return the
-    first and last index of the vocabulary belonging to the `rank`
-    partition: Note that indices in [fist, last)"""
+    """ Split the vocabulary into `world_size` chunks and return the first
+    and last index of the vocabulary belonging to the `rank`
+    partition: Note that indices in [first, last)
+
+    """
 
     @staticmethod
     def vocab_range_from_per_partition_vocab_size(
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index eb13053fc6..f58f961fd0 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -21,35 +21,6 @@ def divide(numerator, denominator):
     return numerator // denominator
 
 
-def split_tensor_into_1d_equal_chunks(tensor):
-    """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = (
-        torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size()
-    )
-    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    return data[start_index:end_index]
-
-
-def gather_split_1d_tensor(tensor):
-    """Opposite of above function, gather values from model parallel ranks."""
-    world_size = parallel_state.get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
-    gathered = torch.empty(
-        numel_gathered,
-        dtype=tensor.dtype,
-        device=torch.cuda.current_device(),
-        requires_grad=False,
-    )
-    torch.distributed._all_gather_base(
-        gathered,
-        tensor,
-        group=parallel_state.get_tensor_model_parallel_group()
-    )
-    return gathered
-
-
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
     Caller should ensure that buffers of the same name
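The two relocated helpers above are inverses across the tensor model parallel group. A minimal round-trip sketch (assumes an initialized TP group and a tensor whose element count divides evenly by the TP world size; `full` is a stand-in tensor, identical on every rank):

    import torch
    from megatron.core import tensor_parallel

    full = torch.arange(16, dtype=torch.float32,
                        device=torch.cuda.current_device())

    # Keep only this rank's contiguous 1D slice; new_buffer=True copies
    # into a fresh tensor instead of returning a view.
    local = tensor_parallel.split_tensor_into_1d_equal_chunks(full, new_buffer=True)

    # All-gather the per-rank slices back into the full 1D tensor.
    restored = tensor_parallel.gather_split_1d_tensor(local)
    assert torch.equal(restored, full)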
diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py
index f7b3b961b8..c08f067923 100644
--- a/megatron/data/biencoder_dataset_utils.py
+++ b/megatron/data/biencoder_dataset_utils.py
@@ -4,7 +4,8 @@
 import numpy as np
 import torch
 
-from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron import get_args, get_tokenizer, print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, \
     pad_and_convert_to_numpy
 from megatron.data.data_samplers import MegatronPretrainingSampler
@@ -57,7 +58,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     query_tokens = data_b['query_tokens'].long()
diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py
index eeadf87127..6b642bccac 100644
--- a/megatron/data/blendable_dataset.py
+++ b/megatron/data/blendable_dataset.py
@@ -8,8 +8,6 @@
 import torch
 
 from megatron import print_rank_0
-from megatron import mpu
-
 
 class BlendableDataset(torch.utils.data.Dataset):
diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index fa4a9d1bea..8dec2c1922 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -8,7 +8,7 @@
 import numpy as np
 from torch.utils.data import Dataset
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 
 
 def build_pretraining_data_loader(dataset, consumed_samples):
diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py
index 426e965c85..cb13bb225a 100644
--- a/megatron/data/dataset_utils.py
+++ b/megatron/data/dataset_utils.py
@@ -28,9 +28,9 @@
 
 from megatron import (
     get_args,
-    mpu,
     print_rank_0
 )
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py
index 1c44006991..f426965ee5 100644
--- a/megatron/data/gpt_dataset.py
+++ b/megatron/data/gpt_dataset.py
@@ -8,7 +8,8 @@
 import numpy as np
 import torch
 
-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
diff --git a/megatron/data/orqa_wiki_dataset.py b/megatron/data/orqa_wiki_dataset.py
index b2e10f3300..4019cd764c 100644
--- a/megatron/data/orqa_wiki_dataset.py
+++ b/megatron/data/orqa_wiki_dataset.py
@@ -9,7 +9,8 @@
 import torch
 from torch.utils.data import Dataset
 
-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
+from megatron.core import tensor_parallel
 from megatron.data.biencoder_dataset_utils import make_attention_mask
 
 def get_open_retrieval_wiki_dataset():
@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
     # Broadcast data.
     data = None if data_iterator is None else next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     row_id = data_b['row_id'].long()
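The data paths above all switch from `mpu.broadcast_data` to `tensor_parallel.broadcast_data`. The call pattern the `get_batch`-style functions in this patch share, as a sketch (`data_iterator` is whatever iterator the trainer hands in; only rank 0 of each tensor model parallel group reads it, the other ranks pass None and receive the broadcast under the same keys):

    import torch
    from megatron.core import tensor_parallel

    keys = ['query_tokens', 'query_mask']
    datatype = torch.int64

    data = next(data_iterator) if data_iterator is not None else None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
    query_tokens = data_b['query_tokens'].long()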
diff --git a/megatron/data/realm_dataset_utils.py b/megatron/data/realm_dataset_utils.py
index aecf5549a7..21445573e3 100644
--- a/megatron/data/realm_dataset_utils.py
+++ b/megatron/data/realm_dataset_utils.py
@@ -4,9 +4,10 @@
 import numpy as np
 import torch
 
-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, print_rank_0
 
 
 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     query_tokens = data_b['query_tokens'].long()
diff --git a/megatron/data/realm_index.py b/megatron/data/realm_index.py
index a4b543c7e0..1fa4a309ed 100644
--- a/megatron/data/realm_index.py
+++ b/megatron/data/realm_index.py
@@ -7,7 +7,7 @@
 import torch
 
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 
 
 def detach(tensor):
@@ -50,10 +50,10 @@ def clear(self):
 
     def load_from_file(self):
         """Populate members from instance saved to file"""
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
         state_dict = pickle.load(open(self.embedding_path, 'rb'))
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)
 
         self.embed_data = state_dict['embed_data']
@@ -137,7 +137,7 @@ def _set_mips_index(self):
         except ImportError:
             raise Exception("Error: Please install faiss to use FaissMIPSIndex")
 
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)
 
         cpu_index = faiss.IndexFlatIP(self.embed_size)
@@ -149,12 +149,12 @@ def _set_mips_index(self):
                 config.useFloat16 = True
                 gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
                 self.mips_index = faiss.IndexIDMap(gpu_index)
-                if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+                if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                     print(">> Initialized index on GPU", flush=True)
             else:
                 # CPU index supports IDs so wrap with IDMap
                 self.mips_index = faiss.IndexIDMap(cpu_index)
-                if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+                if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                     print(">> Initialized index on CPU", flush=True)
 
         # if we were constructed with a BlockData, then automatically load it
@@ -199,7 +199,7 @@ def add_embed_data(self, all_embed_data):
 
         self.mips_index.add_with_ids(embeds_arr, indices_arr)
 
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)
 
     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
diff --git a/megatron/indexer.py b/megatron/indexer.py
index d2ff9e36f8..45f530a7d4 100644
--- a/megatron/indexer.py
+++ b/megatron/indexer.py
@@ -4,7 +4,7 @@
 import torch.distributed as dist
 
 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
diff --git a/megatron/initialize.py b/megatron/initialize.py
index d7bc629b7e..ba68a8fd5b 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -14,13 +14,10 @@
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron.arguments import (parse_args, validate_args)
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
-from megatron.mpu import (set_tensor_model_parallel_rank,
-                          set_tensor_model_parallel_world_size)
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu
@@ -65,13 +62,14 @@ def finish_mpu_init():
     args = get_args()
 
     if args.lazy_mpu_init:
+        # TODO is this still a necessary option?
         args.use_cpu_initialization=True
         # delayed initialization of DDP-related stuff
-        # We only set basic DDP globals
-        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+        # We only set basic DDP globals
+        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
         # and return function for external DDP manager
         # to call when it has DDP initialized
-        set_tensor_model_parallel_rank(args.rank)
+        mpu.set_tensor_model_parallel_rank(args.rank)
         return finish_mpu_init
     else:
         # Megatron's MPU is the master. Complete initialization right away.
@@ -147,7 +145,7 @@ def _compile_dependencies():
 
 
 def _initialize_distributed():
-    """Initialize torch.distributed and mpu."""
+    """Initialize torch.distributed and core model parallel."""
     args = get_args()
 
     device_count = torch.cuda.device_count()
@@ -185,17 +183,14 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size,
-                                          args.pipeline_model_parallel_split_rank)
-            core.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
                                           args.virtual_pipeline_model_parallel_size,
                                           args.pipeline_model_parallel_split_rank)
-            print(f'> initialized tensor model parallel with size '
-                  f'{core.get_tensor_model_parallel_world_size()}')
-            print(f'> initialized pipeline model parallel with size '
-                  f'{core.get_pipeline_model_parallel_world_size()}')
+            if args.rank == 0:
+                print(f'> initialized tensor model parallel with size '
+                      f'{mpu.get_tensor_model_parallel_world_size()}')
+                print(f'> initialized pipeline model parallel with size '
+                      f'{mpu.get_pipeline_model_parallel_world_size()}')
 
 
 def _init_autoresume():
@@ -219,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
         np.random.seed(seed)
         torch.manual_seed(seed)
         if torch.cuda.device_count() > 0:
-            core.tensor_parallel.model_parallel_cuda_manual_seed(seed)
+            tensor_parallel.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py
index 921356a06c..c05f1a71cd 100644
--- a/megatron/model/bert_model.py
+++ b/megatron/model/bert_model.py
@@ -5,7 +5,7 @@
 import torch
 
 from megatron import get_args
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
@@ -61,7 +61,7 @@ def __init__(self, mpu_vocab_size, hidden_size, init_method,
         args = get_args()
 
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
+        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output
 
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
@@ -110,9 +110,9 @@ def post_language_model_processing(lm_output, pooled_output,
         # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
-            lm_loss = core.vocab_parallel_cross_entropy(lm_logits.float(),
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                         lm_labels)
         # [s, b] => [b s]
         lm_loss = lm_loss.transpose(0,1).contiguous()
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index 9d10e948e4..c910879dc8 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -2,11 +2,11 @@
 import torch
 import sys
 
-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_0, get_tokenizer
+from megatron.core import mpu
 from megatron.checkpointing import fix_query_key_value_ordering
 from megatron.checkpointing import get_checkpoint_tracker_filename
 from megatron.checkpointing import get_checkpoint_name
-from megatron import mpu, get_tokenizer
 from megatron.model.bert_model import bert_position_ids
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import get_language_model
diff --git a/megatron/model/classification.py b/megatron/model/classification.py
index 93bd3c8555..54a452065a 100644
--- a/megatron/model/classification.py
+++ b/megatron/model/classification.py
@@ -5,7 +5,6 @@
 import torch
 
 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
diff --git a/megatron/model/distributed.py b/megatron/model/distributed.py
index f55de1d891..f91f8a63e3 100644
--- a/megatron/model/distributed.py
+++ b/megatron/model/distributed.py
@@ -8,7 +8,7 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 from .module import MegatronModule
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 15fc0b6c15..06b59791e6 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -5,8 +5,7 @@
 import torch
 
 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from .module import MegatronModule
 from .enums import AttnMaskType
@@ -34,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         labels = labels.transpose(0,1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output, labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
         else:
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
 
         # [s b] => [b, s]
         loss = loss.transpose(0,1).contiguous()
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index 9bc4d71ffd..7888153cd8 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 
 from megatron import get_args
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
@@ -22,15 +22,15 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if args.async_tensor_model_parallel_allreduce or\
             args.sequence_parallel:
         input_parallel = input_
-        model_parallel = core.get_tensor_model_parallel_world_size() > 1
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel and not args.sequence_parallel
     else:
-        input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
+        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
 
     # Matrix multiply.
-    logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
         input=input_parallel,
         weight=word_embeddings_weight,
         bias=bias,
@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel
 
-    return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
+    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
 
 
 def get_language_model(num_tokentypes, add_pooler,
@@ -106,7 +106,7 @@ def forward(self, hidden_states, sequence_index=0):
         # gather data along sequence dimensions
         # same pooler is run on all tensor parallel nodes
         if self.sequence_parallel:
-            hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
+            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                 hidden_states,
                 tensor_parallel_output_grad=False)
@@ -146,7 +146,7 @@ def __init__(self,
         args = get_args()
 
         # Word embeddings (parallel).
-        self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
             vocab_size, self.hidden_size,
             init_method=self.init_method,
             params_dtype=args.params_dtype,
@@ -229,8 +229,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
 
         # Dropout.
         if self.sequence_parallel:
-            embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
             embeddings = self.embedding_dropout(embeddings)
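The BERT and GPT loss paths above (and T5 below) now share one spelling of the parallel cross entropy. The common pattern, sketched (`lm_logits` is the [s, b, partitioned-vocab] output of `parallel_lm_logits` and `lm_labels` the [s, b] targets):

    from megatron.core import tensor_parallel

    if fp16_lm_cross_entropy:
        # Fused fp16 path; logits must already be half precision.
        assert lm_logits.dtype == torch.half
        lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
    else:
        # Upcast first; the loss is reduced over the vocab-parallel
        # dimension across the tensor model parallel group.
        lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                               lm_labels)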
diff --git a/megatron/model/module.py b/megatron/model/module.py
index 834cc20a8b..1c254181bd 100644
--- a/megatron/model/module.py
+++ b/megatron/model/module.py
@@ -7,8 +7,7 @@
 from torch.nn.parameter import Parameter
 
 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 
 
 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
@@ -77,7 +76,7 @@ def initialize_word_embeddings(self, init_method_normal):
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
             # stage's weights using all_reduce below.
-            self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                 args.padded_vocab_size, args.hidden_size,
                 init_method=init_method_normal(args.init_method_std),
                 params_dtype=args.params_dtype,
diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py
index d9ba4334df..6af06240d4 100644
--- a/megatron/model/multiple_choice.py
+++ b/megatron/model/multiple_choice.py
@@ -5,7 +5,6 @@
 import torch
 
 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py
index fa40e54b84..654f2992f6 100644
--- a/megatron/model/realm_model.py
+++ b/megatron/model/realm_model.py
@@ -5,7 +5,7 @@
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from .module import MegatronModule
-from megatron import mpu
+from megatron.core import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py
index 14e5b6faa5..ab6001f5b3 100644
--- a/megatron/model/t5_model.py
+++ b/megatron/model/t5_model.py
@@ -4,10 +4,8 @@
 
 import torch
 
-from megatron import (
-    get_args,
-    mpu
-)
+from megatron import get_args
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model.transformer import LayerNorm
@@ -151,10 +149,10 @@ def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
             lm_labels = lm_labels.transpose(0,1).contiguous()
             if self.fp16_lm_cross_entropy:
                 assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
             else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                       lm_labels)
             # [s b] => [b s]
             lm_loss = lm_loss.transpose(0,1).contiguous()
         return lm_loss
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index f6e11aaad5..017beb49ee 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -6,10 +6,9 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_timers, get_args
-from megatron.core import get_global_memory_buffer
-from megatron import core
+from megatron import get_timers, get_args, core
 from .module import MegatronModule
+from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
@@ -79,7 +78,7 @@ def __init__(self, init_method, output_layer_init_method):
 
         # Project to 4h.
-        self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
+        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
             args.ffn_hidden_size,
             gather_output=False,
@@ -96,7 +95,7 @@ def __init__(self, init_method, output_layer_init_method):
             self.activation_func = erf_gelu
 
         # Project back to h.
-        self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
+        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,
@@ -189,7 +188,7 @@ def __init__(self, layer_number,
         projection_size = args.kv_channels * args.num_attention_heads
 
         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                            world_size)
         self.hidden_size_per_attention_head = core.utils.divide(
@@ -237,7 +236,7 @@ def forward(self, query_layer, key_layer,
             output_size[0] * output_size[1], -1)
 
         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
             (output_size[0]*output_size[1], output_size[2], output_size[3]),
             query_layer.dtype, "mpu")
@@ -263,7 +262,7 @@ def forward(self, query_layer, key_layer,
         # seem a bit unusual, but is taken from the original Transformer paper.
 
         if not self.sequence_parallel:
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:
             attention_probs = self.attention_dropout(attention_probs)
@@ -327,7 +326,7 @@ def __init__(self, init_method,
         projection_size = args.kv_channels * args.num_attention_heads
 
         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = core.utils.divide(
             projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(
@@ -335,7 +334,7 @@ def __init__(self, init_method,
 
         # Strided linear layer.
         if attention_type == AttnType.self_attn:
-            self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
                 gather_output=False,
@@ -344,7 +343,7 @@ def __init__(self, init_method,
                 **_args_to_kwargs())
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = core.tensor_parallel.ColumnParallelLinear(
+            self.query = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
                 gather_output=False,
@@ -353,7 +352,7 @@ def __init__(self, init_method,
                 **_args_to_kwargs())
 
-            self.key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
                 gather_output=False,
@@ -366,7 +365,7 @@ def __init__(self, init_method,
             self.checkpoint_core_attention = args.recompute_granularity == 'selective'
 
         # Output.
-        self.dense = core.tensor_parallel.RowParallelLinear(
+        self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
             input_is_parallel=True,
@@ -386,7 +385,7 @@ def custom_forward(*inputs):
                 value_layer, attention_mask)
             return output_
 
-        hidden_states = core.tensor_parallel.checkpoint(
+        hidden_states = tensor_parallel.checkpoint(
             custom_forward,
             False, query_layer, key_layer, value_layer, attention_mask)
@@ -439,7 +438,7 @@ def forward(self, hidden_states, attention_mask,
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer,
              key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -452,7 +451,7 @@ def forward(self, hidden_states, attention_mask,
 
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
             (key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
 
             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)
@@ -769,7 +768,7 @@ def __init__(self, init_method, output_layer_init_method,
         self.sequence_parallel = args.sequence_parallel
 
         # Number of layers.
-        self.num_layers = core.get_num_layers(
+        self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
 
         self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
@@ -799,21 +798,21 @@ def build_layer(layer_number):
             # layers to stages like (each list is a model chunk):
             # Stage 0: [0, 1]  [4, 5]
             # Stage 1: [2, 3]  [6, 7]
-            offset = core.get_virtual_pipeline_model_parallel_rank() * (
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
-                (core.get_pipeline_model_parallel_rank() * self.num_layers)
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
             if args.model_type == ModelType.encoder_and_decoder and \
-                    core.get_pipeline_model_parallel_world_size() > 1:
-                pipeline_rank = core.get_pipeline_model_parallel_rank()
+                    mpu.get_pipeline_model_parallel_world_size() > 1:
+                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                 if layer_type == LayerType.encoder:
                     offset = pipeline_rank * self.num_layers
                 else:
                     num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                     offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
             else:
-                offset = core.get_pipeline_model_parallel_rank() * self.num_layers
+                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
 
         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,
@@ -862,7 +861,7 @@ def custom_forward(*inputs):
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
-                hidden_states = core.tensor_parallel.checkpoint(
+                hidden_states = tensor_parallel.checkpoint(
                     custom(l, l + self.recompute_num_layers),
                     self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -874,7 +873,7 @@ def custom_forward(*inputs):
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
-                    hidden_states = core.tensor_parallel.checkpoint(
+                    hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + 1),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -932,7 +931,7 @@ def forward(self, hidden_states, attention_mask,
                 )
 
             if self.sequence_parallel:
-                rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
+                rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
             else:
                 rng_context = nullcontext()
diff --git a/megatron/model/vision/knn_monitor.py b/megatron/model/vision/knn_monitor.py
index d1a7588008..a7d79854eb 100644
--- a/megatron/model/vision/knn_monitor.py
+++ b/megatron/model/vision/knn_monitor.py
@@ -1,6 +1,7 @@
 import torch.nn.functional as F
 import torch
-from megatron import print_rank_0, get_args, mpu
+from megatron import print_rank_0, get_args
+from megatron.core import mpu
 from megatron.data.vit_dataset import ClassificationTransform
 from megatron.data.image_folder import ImageFolder
diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py
deleted file mode 100644
index 9c42b5f87a..0000000000
--- a/megatron/mpu/__init__.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Model parallel utility interface."""
-
-from .initialize import is_unitialized
-from .initialize import destroy_model_parallel
-from .initialize import get_data_parallel_group
-from .initialize import get_data_parallel_rank
-from .initialize import get_data_parallel_world_size
-from .initialize import get_embedding_group
-from .initialize import get_position_embedding_group
-from .initialize import get_model_parallel_group
-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_pipeline_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
-from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
-from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
-from .initialize import is_rank_in_embedding_group
-from .initialize import is_rank_in_position_embedding_group
-from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
-from .initialize import is_pipeline_stage_at_split
-from .initialize import get_num_layers
-from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_data_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_first_rank
-from .initialize import get_pipeline_model_parallel_last_rank
-from .initialize import get_pipeline_model_parallel_next_rank
-from .initialize import get_pipeline_model_parallel_prev_rank
-from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
-from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
-from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
-from .initialize import initialize_model_parallel
-from .initialize import model_parallel_is_initialized
-
-
-from .utils import divide
-from .utils import split_tensor_along_last_dim
diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py
index 6162a3e6ee..2bfe60ff0c 100644
--- a/megatron/optimizer/distrib_optimizer.py
+++ b/megatron/optimizer/distrib_optimizer.py
@@ -8,10 +8,9 @@
 
 from megatron import get_args
 from megatron import get_timers
-from megatron import mpu
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
@@ -290,9 +289,9 @@ def build_model_and_main_param_groups(cls,
                     shard_model_param = model_param.detach().view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param = shard_model_param.clone().float()
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_main_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared
@@ -309,7 +308,7 @@ def build_model_and_main_param_groups(cls,
                         [param_range.start:param_range.end]
                     model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 7c55cdcde6..cdb9c7eaf5 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -11,13 +11,11 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from megatron import get_timers
-from megatron import mpu
-from megatron import core
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from megatron.utils import unwrap_model
 
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
@@ -103,7 +101,7 @@ def get_main_grads_for_grad_norm(self):
             grad = param.grad
             grad_not_none = grad is not None
             is_not_shared = param_is_not_shared(param)
-            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
             if grad_not_none and is_not_shared and is_not_tp_duplicate:
                 grads_for_norm.append(grad)
@@ -528,8 +526,8 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                         # Create a copy
                         main_param = param.detach().clone().float()
                         # Copy tensor model parallel attributes.
-                        core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
-                                                                                   param)
+                        tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
+                                                                              param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
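Both optimizers above lean on the same helper to keep tensor parallel metadata on their fp32 main parameters. A minimal sketch of that idiom (`param` is any half-precision model parameter):

    from megatron.core import tensor_parallel

    main_param = param.detach().clone().float()
    # Copies the tensor_model_parallel / partition_dim / partition_stride
    # attributes so grad-norm and duplicate-filtering logic treat the
    # fp32 copy like the original sharded parameter.
    tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
    if hasattr(param, 'shared'):
        main_param.shared = param.shared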
diff --git a/megatron/p2p_communication.py b/megatron/p2p_communication.py
index 9b478d3754..ba2d55a0d2 100644
--- a/megatron/p2p_communication.py
+++ b/megatron/p2p_communication.py
@@ -4,8 +4,8 @@
 import operator
 import torch
 
-from megatron import get_args
-from megatron import mpu
+from megatron import get_args, core
+from megatron.core import mpu
 
 
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
@@ -81,10 +81,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        args.scatter_gather_tensors_in_pipeline and \
        not args.sequence_parallel:
         if tensor_send_next is not None:
-            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
 
         if tensor_send_prev is not None:
-            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
 
     # Send tensors in both the forward and backward directions as appropriate.
     if args.use_ring_exchange_p2p:
@@ -127,18 +127,18 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        args.scatter_gather_tensors_in_pipeline and \
        not args.sequence_parallel:
         if recv_prev:
-            tensor_recv_prev = mpu.gather_split_1d_tensor(
+            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
-                                                        requires_grad = True,
-                                                        keep_graph = False)
+            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
+                                                               requires_grad = True,
+                                                               keep_graph = False)
 
         if recv_next:
-            tensor_recv_next = mpu.gather_split_1d_tensor(
+            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
-                                                        requires_grad = True,
-                                                        keep_graph = False)
+            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
+                                                               requires_grad = True,
+                                                               keep_graph = False)
 
     return tensor_recv_prev, tensor_recv_next
diff --git a/megatron/schedules.py b/megatron/schedules.py
index b316ade7e8..fd2d378daf 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -8,8 +8,8 @@
 from megatron import get_args
 from megatron import get_num_microbatches
 from megatron import get_timers
-from megatron import mpu
 from megatron import p2p_communication
+from megatron.core import mpu
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py
index c9dac6e867..a9746000c6 100644
--- a/megatron/text_generation/api.py
+++ b/megatron/text_generation/api.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from megatron import mpu
+from megatron.core import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
diff --git a/megatron/text_generation/communication.py b/megatron/text_generation/communication.py
index 45189c9a8b..dee32077f3 100644
--- a/megatron/text_generation/communication.py
+++ b/megatron/text_generation/communication.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from megatron import mpu
+from megatron.core import mpu
 
 
 
diff --git a/megatron/text_generation/forward_step.py b/megatron/text_generation/forward_step.py
index 97544c8761..feb087cbb6 100644
--- a/megatron/text_generation/forward_step.py
+++ b/megatron/text_generation/forward_step.py
@@ -6,9 +6,8 @@
 
 import torch
 
-from megatron import (
-    get_args,
-    mpu)
+from megatron import get_args
+from megatron.core import mpu
 from .communication import (
     send_to_next_pipeline_rank,
     recv_from_prev_pipeline_rank_)
diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py
index a366f19325..b06c287b50 100644
--- a/megatron/text_generation/generation.py
+++ b/megatron/text_generation/generation.py
@@ -5,7 +5,8 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args, get_tokenizer, mpu
+from megatron import get_args, get_tokenizer
+from megatron.core import mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
diff --git a/megatron/training.py b/megatron/training.py
index 1a877a1f50..2e90a681b1 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -19,8 +19,7 @@
 from megatron import get_num_microbatches
 from megatron import is_last_rank
 from megatron import update_num_microbatches
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
@@ -258,7 +257,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # are set for all params so the optimizer can use them.
     for model_module in model:
         for param in model_module.parameters():
-            core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
 
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
diff --git a/megatron/utils.py b/megatron/utils.py
index c4d2cc9416..008f89fa80 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -10,11 +10,13 @@
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
 
-from megatron import get_args
-from megatron import get_adlr_autoresume
-from megatron import mpu
-from megatron.model.module import param_is_not_shared
+from megatron import (
+    get_args,
+    get_adlr_autoresume,
+)
+from megatron.core import mpu
 from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
+from megatron.model.module import param_is_not_shared
 
 
 def unwrap_model(model, module_instances=(torchDDP)):
diff --git a/pretrain_bert.py b/pretrain_bert.py
index 18d84f361b..3edbd6fc8d 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -10,7 +10,7 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel, ModelType
 from megatron.training import pretrain
@@ -46,7 +46,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     tokens = data_b['text'].long()
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 77bf24a8e6..db9effff72 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -8,8 +8,7 @@
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -43,7 +42,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     tokens_ = data_b['text'].long()
diff --git a/pretrain_ict.py b/pretrain_ict.py
index e0836181d0..c942b0c29d 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -12,7 +12,7 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ModelType
diff --git a/pretrain_t5.py b/pretrain_t5.py
index b7c94fcd4d..11832cbcd0 100644
--- a/pretrain_t5.py
+++ b/pretrain_t5.py
@@ -9,9 +9,9 @@
 from megatron import (
     get_args,
     get_timers,
-    mpu,
     print_rank_0
 )
+from megatron.core import tensor_parallel
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import T5Model, ModelType
 from megatron.training import pretrain
@@ -80,7 +80,7 @@ def get_batch(data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
 
     # Unpack.
     tokens_enc = data_b['text_enc'].long()
diff --git a/pretrain_vision_classify.py b/pretrain_vision_classify.py
index a77a8e1852..b9d0711007 100644
--- a/pretrain_vision_classify.py
+++ b/pretrain_vision_classify.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
diff --git a/pretrain_vision_dino.py b/pretrain_vision_dino.py
index 2eb5f9d76a..7095728b77 100644
--- a/pretrain_vision_dino.py
+++ b/pretrain_vision_dino.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch.distributed as dist
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0
+from megatron import get_args, get_timers, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
diff --git a/pretrain_vision_inpaint.py b/pretrain_vision_inpaint.py
index 191a263622..4d26d9f134 100644
--- a/pretrain_vision_inpaint.py
+++ b/pretrain_vision_inpaint.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from functools import partial
-from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
+from megatron import get_args, get_timers, print_rank_0, print_rank_last
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
diff --git a/tasks/eval_utils.py b/tasks/eval_utils.py
index d7d932e948..6b29db345f 100644
--- a/tasks/eval_utils.py
+++ b/tasks/eval_utils.py
@@ -10,7 +10,7 @@
 
 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index c9558a7b19..5ea3dc1830 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -9,7 +9,7 @@
 from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import ModelType
diff --git a/tasks/glue/finetune.py b/tasks/glue/finetune.py
index 4bed99c4e0..0c31b90470 100644
--- a/tasks/glue/finetune.py
+++ b/tasks/glue/finetune.py
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
diff --git a/tasks/msdp/prompt.py b/tasks/msdp/prompt.py
index 00591cfdf9..a4e777e0b8 100644
--- a/tasks/msdp/prompt.py
+++ b/tasks/msdp/prompt.py
@@ -6,10 +6,10 @@
 import torch
 import requests
 from nltk import word_tokenize
-from megatron import mpu
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron.core import mpu
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py
index 8aebadfa59..02966362c9 100644
--- a/tasks/orqa/supervised/eval_utils.py
+++ b/tasks/orqa/supervised/eval_utils.py
@@ -10,7 +10,7 @@
 from torch.utils.data import DataLoader
 
 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.finetune_utils import build_data_loader
 
diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py
index 4510b529ab..c186dcc518 100644
--- a/tasks/orqa/supervised/finetune.py
+++ b/tasks/orqa/supervised/finetune.py
@@ -9,8 +9,8 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args, get_timers, get_tokenizer
-from megatron import mpu, print_rank_0
+from megatron import get_args, get_timers, get_tokenizer, print_rank_0
+from megatron.core import mpu
 from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.utils import average_losses_across_data_parallel_group
diff --git a/tasks/orqa/unsupervised/nq.py b/tasks/orqa/unsupervised/nq.py
index 5ef95af1f5..56fd77c12c 100644
--- a/tasks/orqa/unsupervised/nq.py
+++ b/tasks/orqa/unsupervised/nq.py
@@ -13,7 +13,7 @@
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset, BatchSampler
 
-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
 from megatron.data.biencoder_dataset_utils import make_attention_mask
 
 def get_nq_dataset(qa_data, split):
diff --git a/tasks/race/finetune.py b/tasks/race/finetune.py
index bbcb81a6e7..18b3ff919d 100644
--- a/tasks/race/finetune.py
+++ b/tasks/race/finetune.py
@@ -5,7 +5,6 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
diff --git a/tasks/vision/classification/eval_utils.py b/tasks/vision/classification/eval_utils.py
index 9f4be05ae1..d3eaec4850 100644
--- a/tasks/vision/classification/eval_utils.py
+++ b/tasks/vision/classification/eval_utils.py
@@ -9,7 +9,7 @@
 
 from megatron import get_args
 from megatron import print_rank_0, print_rank_last
-from megatron import mpu
+from megatron.core import mpu
 from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py
index 2c1d8c3dad..3b73707732 100644
--- a/tasks/vision/finetune_utils.py
+++ b/tasks/vision/finetune_utils.py
@@ -7,7 +7,8 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu, utils
+from megatron import utils
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
diff --git a/tasks/vision/segmentation/finetune_segformer.py b/tasks/vision/segmentation/finetune_segformer.py
index 1dee971505..10a4085be4 100644
--- a/tasks/vision/segmentation/finetune_segformer.py
+++ b/tasks/vision/segmentation/finetune_segformer.py
@@ -7,7 +7,8 @@
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
diff --git a/tasks/vision/segmentation/finetune_setr.py b/tasks/vision/segmentation/finetune_setr.py
index 606bf7f523..7f3208d09a 100644
--- a/tasks/vision/segmentation/finetune_setr.py
+++ b/tasks/vision/segmentation/finetune_setr.py
@@ -6,7 +6,8 @@
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers
-from megatron import mpu, print_rank_0, print_rank_last
+from megatron import print_rank_0, print_rank_last
+from megatron.core import mpu
 from tasks.vision.finetune_utils import finetune
 from tasks.vision.finetune_utils import build_data_loader
 from megatron.utils import average_losses_across_data_parallel_group
diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py
index 43db544d0b..d76039673a 100644
--- a/tasks/zeroshot_gpt/evaluate.py
+++ b/tasks/zeroshot_gpt/evaluate.py
@@ -9,7 +9,7 @@
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron.core import mpu, tensor_parallel
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model
@@ -93,7 +93,7 @@ def forward_step(batch, model, eval_metric):
     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.vocab_parallel_cross_entropy(
+            losses = tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())
diff --git a/tools/checkpoint_loader_megatron.py b/tools/checkpoint_loader_megatron.py
index 64dfd8be79..977255335a 100644
--- a/tools/checkpoint_loader_megatron.py
+++ b/tools/checkpoint_loader_megatron.py
@@ -30,7 +30,8 @@ def _load_checkpoint(queue, args):
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
         from megatron.model import ModelType, module
-        from megatron import mpu, fused_kernels
+        from megatron.core import mpu
+        from megatron import fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         queue.put("exit")
@@ -99,7 +100,7 @@ def get_models(count, dtype, pre_process, post_process):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.initialize.set_tensor_model_parallel_rank(rank)
+            mpu.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
@@ -123,8 +124,8 @@
         exit(1)
 
     set_global_variables(margs)
-    mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
+    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)
 
     # Get true (non-padded) vocab size
@@ -162,7 +163,7 @@ def get_models(count, dtype, pre_process, post_process):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
 
     # Get first pipe stage
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
 
@@ -188,7 +189,7 @@ def queue_put(name, msg):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
diff --git a/tools/checkpoint_saver_megatron.py b/tools/checkpoint_saver_megatron.py
index 2695a00ac8..f3a5145a3b 100644
--- a/tools/checkpoint_saver_megatron.py
+++ b/tools/checkpoint_saver_megatron.py
@@ -34,7 +34,8 @@ def save_checkpoint(queue, args):
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
         from megatron.tokenizer.tokenizer import _vocab_size_with_padding
-        from megatron import mpu, fused_kernels
+        from megatron import fused_kernels
+        from megatron.core import mpu
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
@@ -152,10 +153,10 @@ def get_models(count, dtype, pre_process, post_process):
         return models
 
     # fake initializing distributed
-    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
-    mpu.initialize.set_tensor_model_parallel_rank(0)
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
+    mpu.set_tensor_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     fused_kernels.load(margs)
 
     # Embeddings
@@ -197,7 +198,7 @@ def get_models(count, dtype, pre_process, post_process):
     out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
 
     # Make models for first pipeline stage and fill in embeddings
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = args.target_pipeline_parallel_size == 1
     models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
    for tp_rank, model in enumerate(models):
@@ -211,7 +212,7 @@ def get_models(count, dtype, pre_process, post_process):
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == args.target_pipeline_parallel_size - 1
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
 
@@ -317,6 +318,6 @@ def get_models(count, dtype, pre_process, post_process):
         print("ERROR: got some more data but was expecting to be done")
 
     for tp_rank in range(args.target_tensor_parallel_size):
-        mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
+        mpu.set_tensor_model_parallel_rank(tp_rank)
         save_checkpoint(md.iteration, [models[tp_rank]], None, None)
     print("Done!")
diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py
index 1cdc8e23d4..3fdd27bea0 100644
--- a/tools/run_text_generation_server.py
+++ b/tools/run_text_generation_server.py
@@ -8,7 +8,7 @@
 import socket
 from megatron import get_args
 from megatron import print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel