Merge branch 'pipeline_parallel_main' into 'main'
Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
deepakn94 committed Dec 21, 2020
2 parents 3aacd95 + 6e83649 commit 9b174da
Showing 67 changed files with 3,078 additions and 1,085 deletions.
21 changes: 13 additions & 8 deletions README.md
@@ -1,4 +1,4 @@
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel, and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multinode training of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.

Using our GPT-2 model we achieve a perplexity of 10.8 on the WikiText-103 dataset (improving SOTA from 15.8) and an accuracy of 66.5% on the LAMBADA dataset. For BERT training, we swapped the position of the layer normalization and the residual connection in the model architecture (similar to the GPT-2 architecture), which allowed the models to continue to improve as they were scaled up. Our BERT model with 3.9 billion parameters reaches a loss of 1.16, SQuAD 2.0 F1-score of 91.7, and RACE accuracy of 90.9%.

@@ -218,7 +218,12 @@ These scripts use the PyTorch distributed launcher for distributed training. As

The two tiers of parallelism are data and model parallelism. First, we facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of the back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options, use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model parallel sizes. For example, for the 8.3 billion parameter model running on 512 GPUs, the scaling efficiency increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory, and for some configurations (e.g., 2.5 billion parameters using 2-way model parallelism and 1.2 billion parameters with no model parallelism) it can make the overall training slower. We empirically found that using a smaller model in those cases improves the training time.
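
To make the difference concrete, here is a minimal, illustrative sketch of the simple "local" flavor described above: gradients are averaged across data-parallel ranks in one pass after `backward()` returns, with no overlap of communication and computation. This is not the code behind `--DDP-impl local`; it assumes the default process group is already initialized and a `model` exists.

<pre>
import torch.distributed as dist

def allreduce_gradients(model):
    """Average gradients across data-parallel ranks after backward().

    Illustrative only: one all-reduce per parameter at the end of the
    backward pass, no overlap with computation.
    """
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size
</pre>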

Second, we developed a simple and efficient intra-layer model parallel approach. To use model parallelism, add the `--model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. With `WORLD_SIZE` GPUs and `MP_SIZE` model parallel size, `WORLD_SIZE`/`MP_SIZE` GPUs will be used for data parallelism. The default value for `--model-parallel-size` is 1, which will not implement model parallelism.
Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages gives each stage 6 transformer layers). The number of microbatches in a per-pipeline minibatch is controlled by the `--num-microbatches-in-minibatch` argument. With `WORLD_SIZE` GPUs, a tensor-model-parallel size of `TENSOR_MP_SIZE`, and a pipeline-model-parallel size of `PIPELINE_MP_SIZE`, `WORLD_SIZE`/(`TENSOR_MP_SIZE` * `PIPELINE_MP_SIZE`) GPUs will be used for data parallelism. The default value for both `--tensor-model-parallel-size` and `--pipeline-model-parallel-size` is 1, which uses neither form of model parallelism.
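
For a quick, illustrative check of how the three parallelism dimensions divide the GPUs (the sizes below are arbitrary assumptions, not values from this merge request):

<pre>
# Illustrative arithmetic only.
world_size = 16                  # total GPUs handed to the launcher
tensor_mp_size = 2               # --tensor-model-parallel-size
pipeline_mp_size = 4             # --pipeline-model-parallel-size

assert world_size % (tensor_mp_size * pipeline_mp_size) == 0
data_parallel_size = world_size // (tensor_mp_size * pipeline_mp_size)
print(data_parallel_size)        # 2 data-parallel replicas
</pre>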

We have examples of how to use these two different forms of model parallelism in these scripts:
`bash examples/pretrain_bert_distributed_with_mp.sh`

`bash examples/pretrain_gpt2_distributed_with_mp.sh`

Other than these minor changes, the distributed training is identical to the training on a single GPU.

@@ -245,7 +250,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_bert.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--model-parallel-size $MP_SIZE \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch
</pre>

@@ -269,7 +274,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--model-parallel-size $MP_SIZE \
--tensor-model-parallel-size $MP_SIZE \
--DDP-impl torch

</pre>
@@ -362,14 +367,14 @@ We provide several command line arguments, detailed in the scripts listed below,
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.

<pre>
MODEL_PARALLEL_SIZE=2
TENSOR_MODEL_PARALLEL_SIZE=2

VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m

WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--model-parallel-size $MODEL_PARALLEL_SIZE \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
@@ -488,7 +493,7 @@ Further command line arguments are described in the source file [`main.py`](./ta
## BERT Task Evaluation
<a id="race-evaluation"></a>
### RACE Evaluation
The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directories contain the RACE dataset as separate `.txt` files.
The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directories contain the RACE dataset as separate `.txt` files. Note that for RACE, the batch size is the number of RACE queries to evaluate. Since each RACE query has four samples, the effective batch size passed through the model will be four times the batch size specified on the command line.
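
As a small worked example of the note above (the batch size here is an arbitrary assumption):

<pre>
# Illustrative: effective samples per step during RACE fine-tuning.
batch_size = 8                       # value passed on the command line
samples_per_query = 4                # each RACE query carries four candidate answers
effective_batch_size = batch_size * samples_per_query   # 32 samples through the model
</pre>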

<pre>
TRAIN_DATA="data/RACE/train/middle"
2 changes: 1 addition & 1 deletion examples/evaluate_zeroshot_gpt2.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--load $CHECKPOINT \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
2 changes: 1 addition & 1 deletion examples/finetune_mnli_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 5 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
2 changes: 1 addition & 1 deletion examples/finetune_race_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--vocab-file $VOCAB_FILE \
--epochs 3 \
--pretrained-checkpoint $PRETRAINED_CHECKPOINT \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
2 changes: 1 addition & 1 deletion examples/generate_text.sh
@@ -5,7 +5,7 @@ VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt

python tools/generate_samples_gpt2.py \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--load $CHECKPOINT_PATH \
6 changes: 3 additions & 3 deletions examples/merge_mp_bert.sh
@@ -1,13 +1,13 @@
#!/bin/bash

MODEL_PARALLEL_SIZE=2
TENSOR_MODEL_PARALLEL_SIZE=2

VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m

WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--model-parallel-size $MODEL_PARALLEL_SIZE \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
1 change: 0 additions & 1 deletion examples/pretrain_bert.sh
@@ -32,4 +32,3 @@ python pretrain_bert.py \
--eval-interval 1000 \
--eval-iters 10 \
--fp16

2 changes: 1 addition & 1 deletion examples/pretrain_bert_distributed.sh
@@ -15,7 +15,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
46 changes: 46 additions & 0 deletions examples/pretrain_bert_distributed_with_mp.sh
@@ -0,0 +1,46 @@
#!/bin/bash

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=<Specify path and file prefix>_text_sentence
CHECKPOINT_PATH=<Specify path>

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 2 \
--num-microbatches-in-minibatch 2 \
--seq-length 512 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file bert-vocab.txt \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
3 changes: 0 additions & 3 deletions examples/pretrain_gpt2.sh
@@ -38,6 +38,3 @@ python pretrain_gpt2.py \
--eval-interval 1000 \
--eval-iters 10 \
--fp16


set +x
6 changes: 1 addition & 5 deletions examples/pretrain_gpt2_distributed.sh
@@ -17,7 +17,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--model-parallel-size 1 \
--tensor-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
@@ -46,7 +46,3 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--eval-interval 1000 \
--eval-iters 10 \
--fp16



set +x
50 changes: 50 additions & 0 deletions examples/pretrain_gpt2_distributed_with_mp.sh
@@ -0,0 +1,50 @@
#! /bin/bash

# Runs the "345M" parameter model

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=<Specify path and file prefix>_text_document
CHECKPOINT_PATH=<Specify path>

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--num-microbatches-in-minibatch 2 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file gpt2-vocab.json \
--merge-file gpt2-merges.txt \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--checkpoint-activations \
--log-interval 100 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
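
As a rough conceptual sketch, not the scheduler in this merge request, of what `--num-microbatches-in-minibatch 2` means for the script above: each minibatch is split into microbatches that are then streamed through the pipeline stages.

<pre>
import torch

# Conceptual illustration only: split a minibatch the way pipeline
# parallelism would before streaming it through the stages.
num_microbatches = 2                             # --num-microbatches-in-minibatch
minibatch = torch.randint(0, 50257, (4, 1024))   # batch-size 4, seq-length 1024 token ids

microbatches = torch.chunk(minibatch, num_microbatches, dim=0)
for i, microbatch in enumerate(microbatches):
    # A real schedule interleaves forward and backward passes across stages;
    # here we only show the split: two microbatches of shape (2, 1024).
    print(i, microbatch.shape)
</pre>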
17 changes: 16 additions & 1 deletion megatron/__init__.py
@@ -26,16 +26,31 @@
)

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron

def print_rank_0(message):
    """If distributed is initialized print only on rank 0."""
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)

def is_last_rank():
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)

def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)
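
A brief usage sketch of the helpers added above. Our reading is that with pipeline parallelism the loss is computed on the last pipeline stage, so metric logging moves to the last rank; the call sites below are hypothetical and assume `torch.distributed` is already initialized.

<pre>
from megatron import print_rank_0, print_rank_last

print_rank_0("building GPT-2 model ...")            # printed by global rank 0 only
print_rank_last("iteration 100 | lm loss: 4.21")    # printed by the last rank only
</pre>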