bf16 is incompatible with pipe parallelism #963
Comments
Thanks for raising this to our attention. Can you add links to the relevant portions of DeepSpeed and DeeperSpeed to make comparison easier? Also, have you checked the current main branch of DeepSpeed?
There is relevant DeepSpeed code: … I haven't tried the current main branch of DeepSpeed.
If that’s the only relevant code fragment, then it looks like we do support PP + BF16 + ZeRO 1; we just don't have the logic correct. Try changing it to say …
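For reference, a minimal sketch of the kind of guard being discussed. This is purely an illustration under my reading of the thread, not the actual DeepSpeed source; the function name and the exact supported/unsupported split are assumptions:

```python
def select_bf16_optimizer_path(zero_stage: int) -> str:
    """Hypothetical sketch: route bf16 models to an optimizer wrapper.

    A guard that raises for every non-zero ZeRO stage also rejects
    the supported PP + BF16 + ZeRO-1 combination; the fix is to only
    reject the stages that genuinely lack an implementation.
    """
    if zero_stage <= 1:
        # BF16 optimizer wrapper; works together with pipeline parallelism
        return "bf16_optimizer"
    raise NotImplementedError(
        f"bf16 + zero stage {zero_stage} is not implemented"
    )
```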
I changed the code to …
Excellent! Can you open a PR with this patch?
Hi, have you ever encountered the error "AttributeError: 'BF16_Optimizer' object has no attribute 'fp16_groups'"?
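If I read that traceback correctly, some caller is assuming the FP16 optimizer's attribute layout. Here is a hedged sketch of a defensive lookup; the attribute names `fp16_groups` and `bf16_groups` are assumptions about how the two optimizer wrappers store their parameter groups:

```python
def get_param_groups(optimizer):
    """Look up parameter groups regardless of optimizer wrapper.

    Assumption: the fp16 wrapper stores groups as `fp16_groups` and
    the bf16 wrapper as `bf16_groups`; hard-coding the former is what
    would produce the AttributeError quoted above.
    """
    for attr in ("fp16_groups", "bf16_groups"):
        if hasattr(optimizer, attr):
            return getattr(optimizer, attr)
    raise AttributeError(
        "optimizer exposes neither fp16_groups nor bf16_groups")
```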
@Life-0-1 @dashstander has this been addressed?
@Life-0-1 @StellaAthena there's a PR now! Most of the changes were made by DeepSpeed, though they still have an outstanding issue: if you have pp > 0 and ZeRO stage 1, then for now you need to set grad_accum_dtype to fp32.
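A sketch of what that workaround might look like in a GPT-NeoX-style config; the placement of the grad_accum_dtype setting under a "data-types" block is my assumption, mirroring DeepSpeed's data_types config section:

```yaml
{
  # bf16 with pipeline parallelism and ZeRO stage 1
  "pipe-parallel-size": 1,
  "precision": "bfloat16",
  "zero_optimization": {
    "stage": 1
  },
  # workaround from the comment above: accumulate gradients in fp32
  # (key name and placement assumed from DeepSpeed's data_types section)
  "data-types": {
    "grad_accum_dtype": "fp32"
  }
}
```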
I find that bf16 training is only compatible with setting

"pipe-parallel-size": 0,

i.e. pipe parallelism must be disabled. With

"pipe-parallel-size": 1,

the error below occurs: …

Below is the relevant DeepSpeed code, and it says bf16 + zero stages 1, 2, 3 are not implemented. Meanwhile, I also note that this code is different from the latest official DeepSpeed code.
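For concreteness, the two settings being contrasted, as a hedged sketch of the reporter's config (only the pipe-parallel lines are quoted above; the precision line is my assumption):

```yaml
{
  "precision": "bfloat16",    # assumed; not quoted in the report
  # "pipe-parallel-size": 1 raises the NotImplementedError quoted above
  "pipe-parallel-size": 0     # this setting trains fine
}
```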