
bf16 is incompatible with pipe parallelism #963

Closed
Life-0-1 opened this issue Jun 2, 2023 · 8 comments · Fixed by #1032
Labels: bug (Something isn't working)

Life-0-1 commented Jun 2, 2023

I find that bf16 training is only compatible with "pipe-parallel-size": 0, i.e. pipeline parallelism must be disabled.
With "pipe-parallel-size": 1, the following error occurs:

Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "gpt-neox/megatron/training.py", line 778, in train
    loss_dict, skipped_iter = train_step(
  File "gpt-neox/megatron/training.py", line 684, in train_step
    reduced_loss = train_step_pipe(
  File "gpt-neox/megatron/training.py", line 734, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1374, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 258, in _exec_reduce_grads
    raise NotImplementedError()
NotImplementedError

Below is the relevant DeepSpeed code; it shows that bf16 with ZeRO stages 1, 2, or 3 is not implemented under pipeline parallelism.
I also note that this code differs from the latest official DeepSpeed code.

def _exec_reduce_grads(self):
    self._force_grad_boundary = True
    if self.pipeline_enable_backward_allreduce:
        if self.bfloat16_enabled():
            if self.zero_optimization_stage() == 0:
                self._bf16_reduce_grads()
            else:
                # The assert admits ZeRO stage 1, but the unconditional raise
                # below means even bf16 + ZeRO 1 fails here.
                assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported"
                raise NotImplementedError()
        else:
            self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
    self._force_grad_boundary = False

Life-0-1 added the bug label on Jun 2, 2023
StellaAthena (Member) commented:

Thanks for raising this to our attention. Can you add links to the relevant portions of DeepSpeed and DeeperSpeed to make comparison easier?

Also, have you checked if the current main branch of DeepSpeed supports BFloat16 with ZeRO and PP correctly?

Life-0-1 (Author) commented Jun 3, 2023

> Thanks for raising this to our attention. Can you add links to the relevant portions of DeepSpeed and DeeperSpeed to make comparison easier?
>
> Also, have you checked if the current main branch of DeepSpeed supports BFloat16 with ZeRO and PP correctly?

Here is the relevant DeepSpeed code:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/engine.py#L242

I haven't tried the current main branch of DeepSpeed.
From reading its code, bf16 works with PP and ZeRO stage 1.
But in the DeepSpeed code used by gpt-neox, bf16 works with PP only at ZeRO stage 0.

StellaAthena (Member) commented:

If that’s the only relevant code fragment, then it looks like we do support PP + BF16 + ZeRO 1; we just don’t have the logic correct. Try changing it to if self.zero_optimization_stage() <= 1: and see if it works.
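
Concretely, the suggested change would look roughly like this (an untested sketch of _exec_reduce_grads with the stage check loosened, not a vetted DeepSpeed patch):

def _exec_reduce_grads(self):
    self._force_grad_boundary = True
    if self.pipeline_enable_backward_allreduce:
        if self.bfloat16_enabled():
            # Route both ZeRO stage 0 and stage 1 through the bf16 reduce path;
            # higher stages remain unsupported with pipeline parallelism.
            if self.zero_optimization_stage() <= 1:
                self._bf16_reduce_grads()
            else:
                raise NotImplementedError("bf16 + PP supports only ZeRO stages 0 and 1")
        else:
            self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
    self._force_grad_boundary = False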

Life-0-1 (Author) commented Jun 6, 2023

> If that’s the only relevant code fragment, then it looks like we do support PP + BF16 + ZeRO 1; we just don’t have the logic correct. Try changing it to if self.zero_optimization_stage() <= 1: and see if it works.

I changed the code to if self.zero_optimization_stage() <= 1:, but two other changes are also needed for training to run successfully.

  1. Add "data_types": {"grad_accum_dtype": "fp32"} to the config file (see the sketch below). By default, grad_accum_dtype is the same as the model dtype, i.e. bf16, and with that combination the ZeRO optimizer is used instead of the bf16 optimizer. We need the bf16 optimizer, because self._bf16_reduce_grads() is only defined there.
  2. Remove the overflow check in megatron/utils.py, since the bf16 optimizer has no .overflow attribute.
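
For reference, a minimal sketch of where the addition from point 1 sits, assuming an otherwise standard GPT-NeoX JSON config with pipeline parallelism and ZeRO stage 1 enabled (all other keys omitted):

{
  "pipe-parallel-size": 1,

  "zero_optimization": {
    "stage": 1
  },

  "data_types": {
    "grad_accum_dtype": "fp32"
  }
}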

StellaAthena (Member) commented:

Excellent! Can you open a PR with this patch?

jiezhangGt commented:

Hi, have you ever met the error "AttributeError: 'BF16_Optimizer' object has no attribute 'fp16_groups'"?

StellaAthena (Member) commented:

@Life-0-1 @dashstander has this been addressed?

dashstander linked a pull request on Sep 15, 2023 that will close this issue.
dashstander (Contributor) commented:

@Life-0-1 @StellaAthena there's a PR now! Most of the changes were made in DeepSpeed itself, though one issue is still outstanding: if you have pp > 0 and ZeRO stage 1, you currently need to set grad_accum_dtype to fp32.
