Pipeline parallelism leads to slower speeds #903
hi,
I trained a 20B model with 128 GPUs and the speed was very slow with "pipe-parallel-size=4, model-parallel-size=2" (TFLOPS=33), but when I changed to "pipe-parallel-size=1, model-parallel-size=2" the TFLOPS jumped to 170. Why is this happening?
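For reference, the two runs differ only in the pipeline degree; in NeoX config terms (a sketch of just the relevant keys, not the full file):

```yaml
# slow run  (~33 TFLOPS):  "pipe-parallel-size": 4, "model-parallel-size": 2
# fast run  (~170 TFLOPS): "pipe-parallel-size": 1, "model-parallel-size": 2
```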
Parallelism settings need to be tuned to your specific hardware to get the best performance. If you're getting 170 TFLOPS on your system without pipeline parallelism, then not using pipeline parallelism seems like the best option :) There's no specific reason you should desire to use it, it's just helpful for some hardware set-ups / model configurations. |
Actually, I plan to train a larger model (175B), so pipeline parallelism is necessary. |
I see. Can you share some technical details (feel free to email me if you don’t want to do so publicly) about your computing cluster? Also the exact configuration file you’re using. It seems likely that one of three things is happening:
|
GPU: A100-SXM-80G, nodes=16, per_node_gpus=8
Config:
|
@StellaAthena Have you trained a 175B model? What config did you use? |
@ShivanshuPurohit can you help with this? You have the most experience running GPT-NeoX at very large scales. |
I've run 175B before. Sharing the exact config wouldn't be helpful because it's cluster-specific, but all the fundamental values are in https://github.com/EleutherAI/gpt-neox/blob/main/configs/175B.yml. Further, I made the following key changes:
Model-parallelism within nodes and pipeline parallelism across nodes:
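The exact values didn't survive the copy here, but for 8-GPU nodes that layout looks roughly like this (hypothetical numbers illustrating the principle, not my actual file):

```yaml
# Keep tensor/model parallelism inside a node (fast NVLink),
# and let pipeline stages cross node boundaries, where the only
# traffic is activations at the stage cut points.
"model-parallel-size": 8,   # = GPUs per node
"pipe-parallel-size": 8,    # stages span nodes; data parallelism fills the remainder
```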
You also need to tune the batch size to be large enough to fill GPU memory (and improve efficiency), yet small enough to preserve accuracy. We find this range to be 2M-8M tokens. I see that your seqlen is 4096, so your parallelism settings will be different. I recommend you make the following configuration changes:
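To make the token math concrete (illustrative arithmetic, not additional recommendations from this thread):

```yaml
# tokens per step = global batch size (in sequences) × seq_length
# with "seq-length": 4096:
#   512 sequences/step  ≈ 2.1M tokens  (low end of the 2M–8M range)
#   2048 sequences/step ≈ 8.4M tokens  (high end)
```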
|
Increasing the pipeline parallel degree (1 --> 4) will increase communication significantly. You should reduce parallelism wherever possible.
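For intuition on why (a standard pipeline-bubble estimate, not a number from this thread): with $p$ pipeline stages and $m$ micro-batches per optimizer step, the fraction of each step lost to the pipeline fill/drain bubble is roughly

$$\frac{p-1}{m+p-1}$$

so raising $p$ from 1 to 4 without also raising $m$ idles the GPUs for a larger share of every step, on top of the extra inter-node sends. |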
@cdj0311 Related to this, rereading your 175B config file I see you’ve specified PP = 16. That’s a ton of pipeline parallelism and almost certainly a major problem. Did you try 4 or 8 before going to 16? Some back-of-the-envelope calculations make it seem like 16 isn’t anywhere close to necessary. |
@Quentin-Anthony @StellaAthena I modified the 175B config as follows:
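(The actual diff was lost when this thread was copied; given the advice above, the change was presumably of this shape, with hypothetical values:)

```yaml
# Hypothetical reconstruction of the change, not the commenter's real file:
"pipe-parallel-size": 4,    # reduced from 16
"model-parallel-size": 8,   # kept within a node
```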
The TFLOPS increased from 25 to 46, which is still very low, but BLOOM-176B (Megatron+DeepSpeed) gets 140+ TFLOPS. Is there a way to avoid pipeline parallelism entirely? Or are there other ways to improve GPU utilization? |
Setting PP=1 will not use pipeline parallelism, and setting PP=0 will bypass the pipeline-parallel module entirely, using the plain (non-pipelined) model class instead. Looking at the BLOOM slurm commands, I see that they’re using a curious arrangement of parallelism:
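(The slurm excerpt didn't survive the copy; as best I can reconstruct it from the published BLOOM setup, treat these numbers as approximate:)

```yaml
# Reconstruction from the published BLOOM-176B setup, not the original
# slurm snippet: 48 nodes × 8×A100 = 384 GPUs,
# tensor-parallel = 4 (within a node), pipeline-parallel = 12 (across nodes),
# data-parallel degree = 384 / (4 × 12) = 8
```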
When you say you get 140 TFLOPs using their code, do you mean you’re using their parallelism settings as well? Or just the codebase with custom parallelism settings? |
For BLOOM, I just found those numbers in the codebase description, so I'm not sure that running it on my cluster would get the same results. |
I can get about 120 TFLOPs on 64 nodes (512 A100 40GB GPUs) with a 175B model at seqlen=2048. It's system-dependent, so don't put too much stock in the exact number, but you should be able to get above 100 TFLOPs. Your micro batch size is very high. Can you reduce it to 1 and also reduce pp to 1? If you can remove those inter-node comms, that'd help a lot.
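In config terms, the suggestion amounts to something like this (a sketch of the relevant keys only):

```yaml
"pipe-parallel-size": 1,               # drop the inter-node pipeline traffic
"train_micro_batch_size_per_gpu": 1,   # shrink per-GPU activation memory
# keep the global batch size via gradient accumulation instead
```
|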
Setting pp=1, mp=8, batch_size=1, seqlen=2048 gives GPU OOM. |