
pipeline parallelism leads to slower speeds #903

Closed · cdj0311 opened this issue Apr 25, 2023 · 14 comments

@cdj0311 commented Apr 25, 2023

Hi,
I trained a 20B model on 128 GPUs. With "pipe-parallel-size=4, model-parallel-size=2" the speed is very slow (33 TFLOPS), but when I change to "pipe-parallel-size=1, model-parallel-size=2" I get 170 TFLOPS. Why is this happening?

@StellaAthena (Member)

Parallelism settings need to be tuned to your specific hardware to get the best performance. If you're getting 170 TFLOPS on your system without pipeline parallelism, then not using pipeline parallelism seems like the best option :) There's no specific reason you should desire to use it, it's just helpful for some hardware set-ups / model configurations.

@cdj0311 (Author) commented Apr 26, 2023

> Parallelism settings need to be tuned to your specific hardware to get the best performance. If you're getting 170 TFLOPS on your system without pipeline parallelism, then not using pipeline parallelism seems like the best option :) There's no specific reason you should desire to use it, it's just helpful for some hardware set-ups / model configurations.

Actually, I plan to train a larger model (175B), so pipeline parallelism is necessary.

@StellaAthena (Member) commented Apr 26, 2023

> > Parallelism settings need to be tuned to your specific hardware to get the best performance. If you're getting 170 TFLOPS on your system without pipeline parallelism, then not using pipeline parallelism seems like the best option :) There's no specific reason you should desire to use it, it's just helpful for some hardware set-ups / model configurations.
>
> Actually, I plan to train a larger model (175B), so pipeline parallelism is necessary.

I see. Can you share some technical details (feel free to email me if you don’t want to do so publicly) about your computing cluster? Also the exact configuration file you’re using.

It seems likely that one of three things is happening:

  1. Your set-up isn’t effectively utilizing tensor cores (this would happen if the per-GPU matmuls have side lengths not divisible by large powers of two); a quick sanity check is sketched after this list.
  2. The default autoparallelization isn’t working on your hardware for some reason.
  3. Your interconnect is introducing substantial bottlenecks and/or there is an anomalously large pipeline bubble.
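For point 1, here is a minimal hand-rolled sanity check (not a NeoX utility; the dimension formulas are simplified assumptions) that you can run against your own model dimensions:

# Hypothetical helper: check whether the main per-GPU GEMM side lengths are
# divisible by a large power of two (tensor cores prefer this).
def check_tensor_core_dims(hidden_size, mp_size, multiple=128):
    dims = {
        "hidden_size": hidden_size,
        "hidden_size / mp": hidden_size // mp_size,
        "ffn_dim / mp (4 * hidden)": 4 * hidden_size // mp_size,
    }
    for name, d in dims.items():
        status = "OK" if d % multiple == 0 else f"NOT divisible by {multiple}"
        print(f"{name} = {d}: {status}")

# Example with made-up dimensions; substitute your own model's values.
check_tensor_core_dims(hidden_size=6144, mp_size=2)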

@cdj0311 (Author) commented Apr 27, 2023

> > > Parallelism settings need to be tuned to your specific hardware to get the best performance. If you're getting 170 TFLOPS on your system without pipeline parallelism, then not using pipeline parallelism seems like the best option :) There's no specific reason you should desire to use it, it's just helpful for some hardware set-ups / model configurations.
> >
> > Actually, I plan to train a larger model (175B), so pipeline parallelism is necessary.
>
> I see. Can you share some technical details (feel free to email me if you don’t want to do so publicly) about your computing cluster? Also the exact configuration file you’re using.
>
> It seems likely that one of three things is happening:
>
>   1. Your set-up isn’t effectively utilizing tensor cores (this would happen if the per-GPU matmuls have side lengths not divisible by large powers of two).
>   2. The default autoparallelization isn’t working on your hardware for some reason.
>   3. Your interconnect is introducing substantial bottlenecks and/or there is an anomalously large pipeline bubble.

GPU: A100-SXM-80G, nodes=16, per_node_gpus=8

config:

{
  "vocab-file": "tokenizer.json",
  "save": "checkpoints",
  "load": "checkpoints",
  "global-num-gpus": 128,
  "attention_config": [[["flash"], 96]],
  "data-path": "./data_text_document",
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe-parallel-size": 16,
  "model-parallel-size": 8,

  # model settings
  "num-layers": 96,
  "hidden-size": 12288,
  "num-attention-heads": 96,
  "seq-length": 4096,
  "max-position-embeddings": 4096,
  "norm": "layernorm",
  "pos-emb": "rotary",
  "rotary_pct": 0.25,
  "no-weight-tying": true,
  "gpt_j_residual": true,
  "output_layer_parallelism": "column",
  "scaled-upper-triang-masked-softmax-fusion": true,
  "bias-gelu-fusion": true,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.97e-4,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },

  "min_lr": 0.97e-5,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
  "stage": 1,
  "allgather_partitions": True,
  "allgather_bucket_size": 1260000000,
  "overlap_comm": True,
  "reduce_scatter": True,
  "reduce_bucket_size": 1260000000,
  "contiguous_gradients": True,
  },

  # batch / data settings (assuming 96 GPUs)
  "train_micro_batch_size_per_gpu": 8,
  "gradient_accumulation_steps": 1,
  "data-impl": "mmap",
  "split": "995,4,1",

  # activation checkpointing
  "checkpoint-activations": true,
  "checkpoint-num-layers": 1,
  "partition-activations": false,
  "synchronize-each-layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight-decay": 0.01,
  "hidden-dropout": 0,
  "attention-dropout": 0,

  # precision settings
  "fp16": {
    "fp16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1
    },

  # misc. training settings
  "train-iters": 150000,
  "lr-decay-iters": 150000,

  "distributed-backend": "nccl",
  "lr-decay-style": "cosine",
  "warmup": 0.01,
  "checkpoint-factor": 500, # this variable previously called `save-interval`
  "eval-interval": 1000,
  "eval-iters": 10,

  # logging
  "log-interval": 1,
  "steps_per_print": 10,
  "wall_clock_breakdown": false,

  ### NEW DATA: ####
  "tokenizer_type": "HFTokenizer",
  "tensorboard-dir": "./tensorboard",
  "log-dir": "./logs",

}

@cdj0311 (Author) commented Apr 27, 2023

@StellaAthena Have you trained a 175B model? What config did you use?

@StellaAthena (Member)

@ShivanshuPurohit can you help with this? You have the most experience running GPT-NeoX at very large scales.

@Quentin-Anthony (Member)

I've run 175B before. Sharing the exact config wouldn't be helpful because it's cluster-specific, but all the fundamental values are in https://github.com/EleutherAI/gpt-neox/blob/main/configs/175B.yml

Further, I made the following key changes:

  • Make the model shorter and wider for improved efficiency:
-   "num-layers": 96,
-   "hidden-size": 12288,
-   "num-attention-heads": 96,
+   "num-layers": 70,
+   "hidden-size": 14336,
+   "num-attention-heads": 112,

Model-parallelism within nodes and pipeline parallelism across nodes:

-   "pipe-parallel-size": 1,
-   "model-parallel-size": 1,
+   "pipe-parallel-size": 4,
+   "model-parallel-size": 8,

You also need to tune the batch size to be large enough to fill GPU memory (and improve efficiency), yet small enough to preserve accuracy. We find this range to be 2M-8M tokens.

I see that your seqlen is 4096, so your parallelism settings will be different. I recommend you make the following configuration changes:

  • Reduce the pipeline-parallel degree and the train batch size. Your global batch size (train_batch_size = train_micro_batch_size_per_gpu * (num_gpus / (pipe-parallel-size * model-parallel-size)) * gradient_accumulation_steps) will therefore remain about the same, yet you'll spend less time in communication (more data-parallel replicas with a smaller batch size each, rather than a few large replicas with a bigger batch size); a sketch of the arithmetic follows this list. I also recommend setting "partition-activations": true, which will split the activation memory across your tensor-parallel GPUs.
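A minimal sketch of that arithmetic, assuming the 128-GPU cluster and seqlen 4096 above (the alternative settings in the second call are illustrative, not a prescription):

# train_batch_size = micro_batch_per_gpu * (num_gpus / (pp * mp)) * gas
def global_batch(micro_batch, num_gpus, pp, mp, gas, seqlen=4096):
    dp = num_gpus // (pp * mp)            # data-parallel replicas
    sequences = micro_batch * dp * gas    # sequences per optimizer step
    return dp, sequences, sequences * seqlen  # tokens per step

# Config as posted: PP=16, MP=8, micro-batch 8, gas 1
print(global_batch(8, 128, 16, 8, 1))   # dp=1 -> 8 sequences,  ~33K tokens/step
# Hypothetical alternative: PP=4, MP=8, micro-batch 4, gas 32
print(global_batch(4, 128, 4, 8, 32))   # dp=4 -> 512 sequences, ~2.1M tokens/step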

@Quentin-Anthony (Member)

> Hi, I trained a 20B model on 128 GPUs. With "pipe-parallel-size=4, model-parallel-size=2" the speed is very slow (33 TFLOPS), but when I change to "pipe-parallel-size=1, model-parallel-size=2" I get 170 TFLOPS. Why is this happening?

Increasing the pipeline parallel degree (1 --> 4) will increase communication significantly. You should reduce parallelism wherever possible.

@StellaAthena (Member)

> > Hi, I trained a 20B model on 128 GPUs. With "pipe-parallel-size=4, model-parallel-size=2" the speed is very slow (33 TFLOPS), but when I change to "pipe-parallel-size=1, model-parallel-size=2" I get 170 TFLOPS. Why is this happening?
>
> Increasing the pipeline parallel degree (1 --> 4) will increase communication significantly. You should reduce parallelism wherever possible.

@cdj0311 Related to this, rereading your 175B config file I see you’ve specified PP = 16. That’s a ton of pipeline parallelism and almost certainly a major problem. Did you try 4 or 8 before going to 16? Some back-of-the-envelope calculations make it seem like 16 isn’t anywhere close to necessary; rough numbers are sketched below.
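As a rough sketch of that back-of-the-envelope reasoning (not an exact memory model: fp16 weights and grads at 4 bytes/param, Adam states at ~12 bytes/param sharded over the data-parallel ranks by ZeRO stage 1, activations ignored):

# Approximate per-GPU parameter/optimizer memory for a 175B model on 128 GPUs.
def rough_param_mem_gb(params=175e9, num_gpus=128, pp=16, mp=8):
    dp = num_gpus // (pp * mp)
    per_gpu = params / (pp * mp)
    weights_and_grads = 4 * per_gpu          # fp16 weights + fp16 grads
    optimizer = 12 * per_gpu / max(dp, 1)    # ZeRO-1 shards Adam states across DP
    return dp, (weights_and_grads + optimizer) / 2**30

print(rough_param_mem_gb(pp=16, mp=8))  # dp=1: ~20 GB per GPU
print(rough_param_mem_gb(pp=4,  mp=8))  # dp=4: ~36 GB per GPU

On this rough estimate either setting fits comfortably in 80 GB A100 memory, so PP=16 is not needed for memory reasons.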

@cdj0311 (Author) commented May 3, 2023

@Quentin-Anthony @StellaAthena I modified 175B config as follows:

"pipe-parallel-size": 2,
"model-parallel-size": 8,
"num-layers": 70,
"hidden-size": 14336,
"num-attention-heads": 112,
"partition-activations": true,
"train_micro_batch_size_per_gpu": 8,

The TFLOPS increased from 25 to 46, which is still very low, while BLOOM-176B (Megatron + DeepSpeed) reports 140+ TFLOPS. Is there any way to avoid pipeline parallelism? Or are there other ways to improve GPU utilization?

@StellaAthena (Member)

Setting PP=1 will not use pipeline parallelism, and setting PP=0 will not use the pipeline-parallelism module at all, instead using torch.nn.Sequential.

Looking at the BLOOM slurm commands I see that they’re using a curious arrangement of parallelism:

TP=4
PP=12
DP=8
MBS=2

When you say you get 140 TFLOPs using their code, do you mean you’re using their parallelism settings as well? Or just the codebase with custom parallelism settings?

@cdj0311 (Author) commented May 4, 2023

For BLOOM, I only found that number in the codebase description, so I'm not sure running it on my cluster would give the same results.
With GPT-NeoX, a 175B model requires PP >= 2 (with MP = 8), but that can be slow, and I really don't know how to set these parameters; a rough estimate of why is sketched below.
@Quentin-Anthony How many TFLOPS does your 175B model achieve?
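(For what it's worth, a rough weights-plus-gradients estimate, ignoring optimizer states and activations entirely, suggests why MP=8 alone doesn't fit on an 80 GB GPU:)

# fp16 weights + fp16 grads = 4 bytes per parameter; MP=8, 175B params.
params = 175e9
for pp in (1, 2):
    per_gpu_params = params / (8 * pp)
    gb = 4 * per_gpu_params / 2**30
    print(f"MP=8, PP={pp}: ~{gb:.0f} GB/GPU before optimizer states and activations")
# MP=8, PP=1: ~81 GB -> already over an 80 GB A100
# MP=8, PP=2: ~41 GB -> leaves headroom for ZeRO-sharded optimizer states + activations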

@Quentin-Anthony (Member)

I can get about 120 TFLOPs on 64 nodes (512 A100 40GB GPUs) with a 175B on seqlen=2048. It's system-dependent so don't put too much stock into the exact number, but you should be able to get above 100 TFLOPs.

Your micro batch size is very high. Can you reduce it to 1 and also reduce pp to 1? If you can remove those inter-node comms that'd help a lot.

@cdj0311 (Author) commented May 5, 2023

> I can get about 120 TFLOPs on 64 nodes (512 A100 40GB GPUs) with a 175B on seqlen=2048. It's system-dependent so don't put too much stock into the exact number, but you should be able to get above 100 TFLOPs.
>
> Your micro batch size is very high. Can you reduce it to 1 and also reduce pp to 1? If you can remove those inter-node comms that'd help a lot.

Setting pp=1, mp=8, batch_size=1, seqlen=2048 results in GPU OOM.

cdj0311 closed this as completed May 8, 2023