Details about pipeline parallelism implementation in DeepSpeed #1110

Closed
ParamsRaman opened this issue May 26, 2021 · 6 comments
Comments

ParamsRaman commented May 26, 2021

Hi,

I had some questions about the pipeline parallelism implementation in DeepSpeed. Can someone help shed some light on the following?

  1. Which of the following types of pipeline scheduling does DeepSpeed implement in its code?
    (a) Figure 2 in PipeDream-2BW paper (https://arxiv.org/pdf/2006.09503.pdf)
    (b) PipeDream-Flush (1F1B) schedule mentioned in Figure 4 (top) in Megatron 3D paper (https://arxiv.org/pdf/2104.04473.pdf)
    (c) Interleaved 1F1B schedule mentioned in Figure 4 (bottom) in Megatron 3D paper (https://arxiv.org/pdf/2104.04473.pdf)

  2. What communication collective primitives are used while implementing pipeline parallelism?

  3. runtime/pipe/engine.py contains the following comment:
    Note: ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism.
    Is this still true in recent versions of DeepSpeed?

ShadenSmith (Contributor) commented:

Hi @ParamsRaman,

  1. DeepSpeed's pipeline engine schedule is equivalent to the 1F1B schedule (option b, PipeDream-Flush). It was developed independently and open-sourced before that naming scheme existed. Here's an illustration of our pipeline schedule; a small sketch of the per-stage step ordering also follows this list. The trained model will be the same as a model trained with pure data parallelism and the same effective batch size, modulo things like initialization or PRNG effects such as dropout.

    We chose this schedule to facilitate the largest-scale models, for which storing multiple weight buffers might be prohibitive, and because we wanted to first target "exact" approaches without convergence tradeoffs.

    If you're interested in other schedules, DeepSpeed abstracts the pipeline scheduling from the engine components, so you could implement other schedules. Things like parameter stashing would just need to be added to the language of the pipeline scheduler.

[Figure: illustration of DeepSpeed's 1F1B pipeline schedule]

  2. Pipeline communications are implemented using broadcast collectives between groups of size 2. Starting with PyTorch 1.8+, the bundled NCCL version also supports send/recv, so I am preparing to release a new backend that uses send/recv when available (a sketch of both patterns follows this list). Other collectives include AllReduce for gradients, though if you enable ZeRO-1 then we can use ReduceScatter instead.

  3. Yes, ZeRO-2 and ZeRO-3 are incompatible with our pipeline parallelism engine. ZeRO-2 partitions gradients that the pipeline engine assumes are intact. Similarly, ZeRO-3 partitions parameters that the pipeline engine assumes are intact. Note that pipeline parallelism already offers some of these advantages by partitioning the model directly, and ZeRO-1 (with optional offload) can then be combined to further partition the optimizer state.
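
To make the step ordering in item 1 concrete, here is a minimal sketch (illustrative Python only, not DeepSpeed's actual schedule classes; the function name and arguments are made up for this example) of the per-stage forward/backward order that a 1F1B / PipeDream-Flush schedule produces:

```python
def one_f_one_b_steps(stage_id, num_stages, num_microbatches):
    """Yield ('fwd', i) / ('bwd', i) steps for one pipeline stage under a
    1F1B (PipeDream-Flush) schedule. Illustrative sketch only."""
    # Warmup: earlier stages run extra forwards to fill the pipeline.
    num_warmup = min(num_stages - stage_id - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup

    fwd, bwd = 0, 0
    for _ in range(num_warmup):
        yield ("fwd", fwd)
        fwd += 1

    # Steady state: alternate one forward with one backward (hence 1F1B),
    # so only a bounded number of micro-batches of activations are live.
    for _ in range(num_steady):
        yield ("fwd", fwd)
        fwd += 1
        yield ("bwd", bwd)
        bwd += 1

    # Cooldown: drain the remaining backwards; the pipeline is then flushed
    # and the optimizer steps on the fully accumulated gradients, so no
    # extra weight versions (as in PipeDream-2BW) are ever stored.
    for _ in range(num_warmup):
        yield ("bwd", bwd)
        bwd += 1

# e.g. stage 0 of a 4-stage pipeline with 8 micro-batches:
# list(one_f_one_b_steps(0, 4, 8)) ->
# [('fwd', 0), ('fwd', 1), ('fwd', 2), ('fwd', 3), ('bwd', 0), ...]
```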
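
And for item 2, a rough sketch of the two communication patterns described there, using plain torch.distributed calls (the helper names are made up; DeepSpeed wraps this in its own p2p layer):

```python
import torch.distributed as dist

def send_activation_via_pair_broadcast(tensor, src_rank, dst_rank, pair_group):
    """Pre-1.8 NCCL path: emulate point-to-point with a broadcast inside a
    2-rank group. Only the two ranks in `pair_group` should call this; the
    group must have been created collectively up front, e.g.
        pair_group = dist.new_group([src_rank, dst_rank])
    """
    dist.broadcast(tensor, src=src_rank, group=pair_group)

def send_activation_via_p2p(tensor, src_rank, dst_rank):
    """PyTorch 1.8+ path (NCCL with send/recv support): true point-to-point."""
    if dist.get_rank() == src_rank:
        dist.send(tensor, dst=dst_rank)
    elif dist.get_rank() == dst_rank:
        dist.recv(tensor, src=src_rank)
```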

ParamsRaman changed the title from "Confusion about pipeline parallelism schedule in DeepSpeed" to "Details about pipeline parallelism implementation in DeepSpeed" on May 29, 2021
jeffra (Contributor) commented Aug 19, 2021

Please re-open if not resolved.

ParamsRaman (Author) commented:

@jeffra @ShadenSmith
Revisiting this thread since I noticed a few other related GitHub issues filed earlier. I am curious about the current status of pipeline parallelism and its compatibility with ZeRO-2/3, gradient checkpointing, etc.

  1. Fix ZeRO 2 + Pipelining #677 - I found this discussion and PR from earlier this year; I'm wondering if it was merged. Does it work?

  2. Pipeline parallelism and gradient checkpointing (edit: and ZeRO 2!) don’t work together EleutherAI/gpt-neox#62 - I see a number of discussions around PP + ZeRO-2 + activation checkpointing. Has this been resolved and merged?

It would be really helpful if you could summarize the latest status on this. Thanks!


hyunwoongko commented Sep 25, 2021

@ParamsRaman I think PR #980 says pipeline parallelism is incompatible with ZeRO-2 and ZeRO-3. That PR (#980) was merged later than the one you mentioned (#677).


ParamsRaman commented Sep 26, 2021

@hyunwoongko I'm still a bit confused. Do you mean that, because this PR was merged later, PP + ZeRO-2/3 now works in DeepSpeed? Or is it still open?


hyunwoongko commented Sep 26, 2021

Nope. PP + ZeRO-2/3 is not possible. PP needs to accumulate gradients across micro-batches, but ZeRO-2 partitions (chunks) the gradients, so the two are not compatible. Even if it could be implemented, there would be no real performance improvement.
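
To spell out the conflict, here is a toy illustration (plain PyTorch tensors standing in for a stage's flat gradient buffer; no real DeepSpeed APIs, and the sizes are made up):

```python
import torch

world_size, rank, num_microbatches = 2, 0, 4

# What the pipeline engine assumes: one intact, full-size gradient buffer
# that every micro-batch's backward accumulates into until the flush.
full_grads = torch.zeros(8)
for _ in range(num_microbatches):
    full_grads += torch.randn(8)   # accumulate one micro-batch's gradient
# The optimizer step happens here, on the complete accumulated buffer.

# What ZeRO-2 does after *each* backward: reduce-scatter the gradients and
# keep only this rank's shard, freeing the rest of the buffer to save memory.
shard = full_grads.chunk(world_size)[rank].clone()
full_grads = None                  # full buffer released

# If ZeRO-2 freed the buffer after micro-batch 1, micro-batches 2..N would
# have nothing to accumulate into -- the two assumptions contradict.
```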
