Running with bf16 error #939

Closed
Life-0-1 opened this issue May 17, 2023 · 9 comments · Fixed by #941
Labels: bug (Something isn't working), good first issue (Good for newcomers)

Comments

@Life-0-1

Describe the bug
I tested bf16 training with various configurations, including the one given at configs/bf16_125M.yml, but all of them failed.
I list the configurations and corresponding errors below.
The code is running on 8 A100 GPUs with the latest gpt-neox repo.

Config 1, which is the same as configs/bf16_125M.yml

"zero_optimization": {
    "stage": 0,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

"fp16": {
     "enabled": true,
     "type": "bfloat16", # set bf16 as precision
     "loss_scale": 0,
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
"fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32

error

Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "gpt-neox/megatron/training.py", line 778, in train
    loss_dict, skipped_iter = train_step(
  File "gpt-neox/megatron/training.py", line 684, in train_step
    reduced_loss = train_step_pipe(
  File "gpt-neox/megatron/training.py", line 734, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1374, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 658, in _exec_forward_pass
    outputs = super().forward(inputs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1844, in forward
    loss = self.module(*inputs, **kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/module.py", line 359, in forward
    x = self.activation_checkpoint_func(
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 754, in checkpoint
    CheckpointFunction.apply(function, all_outputs, *args)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 588, in forward
    outputs = run_function(*inputs_cuda)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/module.py", line 337, in exec_func
    inputs = layer(inputs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "gpt-neox/megatron/model/transformer.py", line 912, in forward
    return super().forward(hidden_states, attention_mask), attention_mask
  File "gpt-neox/megatron/model/transformer.py", line 874, in forward
    attention_output, attention_bias = self.attention(
  File "anaconda3/envs/neox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "gpt-neox/megatron/model/transformer.py", line 705, in forward
    context_layer = self.attention(
  File "gpt-neox/megatron/model/transformer.py", line 502, in attention
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
RuntimeError: expected scalar type Half but found Float

Config 2

"zero_optimization": {
    "stage": 0,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

"bf16": {
    "enabled": true,
  },

errors

 File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "gpt-neox/megatron/training.py", line 789, in train
    overflow_monitor.check(skipped_iter)  # check for repeated overflow
  File "gpt-neox/megatron/utils.py", line 356, in check
    self.optimizer.overflow
AttributeError: 'BF16_Optimizer' object has no attribute 'overflow'
Life-0-1 added the bug label on May 17, 2023
@Life-0-1 (Author)

Also, if running with zero stage 1, there is a NotImplementedError:

File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "gpt-neox/megatron/training.py", line 778, in train
    loss_dict, skipped_iter = train_step(
  File "gpt-neox/megatron/training.py", line 684, in train_step
    reduced_loss = train_step_pipe(
  File "gpt-neox/megatron/training.py", line 734, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1374, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 258, in _exec_reduce_grads
    raise NotImplementedError()

These DeepSpeed code paths differ from the latest official DeepSpeed release.

@Life-0-1 (Author)

Update:
I find that bfloat16 only works with pipe-parallel-size=0; maybe we should use DeepSpeed pipeline parallelism?
The config below works and is compatible with arguments.py:

"pipe-parallel-size": 0,
"bf16": {
    "enabled": true,
  },
  "fp16": {
    "enabled": false,
    "type": "bfloat16",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "initial_scale_power": 12
  },

StellaAthena added the good first issue label on May 17, 2023
@StellaAthena (Member) commented May 17, 2023

PR #787 changed how we process the bfloat16 configuration, but it doesn't look like the demo file was updated. Apologies for the oversight.

Can you try deleting the fp16 and bf16 arguments and replacing them with {"precision": "bfloat16"}?
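
(For reference, the change amounts to something like the sketch below. This is only an illustration of the suggestion, with "pipe-parallel-size": 0 carried over from the working config earlier in the thread; the exact keys in the updated demo file may differ.)

"pipe-parallel-size": 0,

# replaces both the "fp16" and "bf16" dicts
"precision": "bfloat16",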

@dashstander (Contributor)

Yeah, my apologies @Life-0-1 -- putting it in the fp16 dict is deprecated by DeepSpeed and we decided that adding yet another bf16 config would be confusing. I'll update the demo file

@dashstander (Contributor)

@Life-0-1 does this fix solve your issue?

@Life-0-1 (Author)

> @Life-0-1 does this fix solve your issue?

Yes, this enables bfloat16 training with a clearer config.
But it seems bfloat16 doesn't play well with Megatron pipeline parallelism? I have to set pipe-parallel-size=0.
Is this a bug or expected behavior? Can you point me to some references on this?
In BLOOM, they used DeepSpeed pipeline parallelism, and it seems compatible with bfloat16? I'm not sure about this, so please correct me if I'm wrong.

@dashstander (Contributor)

@Life-0-1 what error do you get when you use pipe_parallel_size > 0?

@Life-0-1 (Author) commented May 19, 2023

> @Life-0-1 what error do you get when you use pipe_parallel_size > 0?

This error occurs with "pipe-parallel-size": 1 and zero stage 1.

Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "gpt-neox/megatron/training.py", line 778, in train
    loss_dict, skipped_iter = train_step(
  File "gpt-neox/megatron/training.py", line 684, in train_step
    reduced_loss = train_step_pipe(
  File "gpt-neox/megatron/training.py", line 734, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1374, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "anaconda3/envs/neox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 258, in _exec_reduce_grads
    raise NotImplementedError()
NotImplementedError

Below is the relevant DeepSpeed code; the logic behind it is that bf16 only works with zero optimization stage 0 here, since the stage-1 branch simply raises NotImplementedError.

def _exec_reduce_grads(self):
    self._force_grad_boundary = True
    if self.pipeline_enable_backward_allreduce:
        if self.bfloat16_enabled():
            if self.zero_optimization_stage() == 0:
                self._bf16_reduce_grads()
            else:
                assert self.zero_optimization_stage() == 1, "only bf16 + z1 are supported"
                raise NotImplementedError()
        else:
            self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
    self._force_grad_boundary = False

After setting the zero stage to 0, the error below occurs:

Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "gpt-neox/megatron/training.py", line 789, in train
    overflow_monitor.check(skipped_iter)  # check for repeated overflow
  File "gpt-neox/megatron/utils.py", line 356, in check
    self.optimizer.overflow
AttributeError: 'BF16_Optimizer' object has no attribute 'overflow'
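
(For context, the failing check reads self.optimizer.overflow, and DeepSpeed's BF16_Optimizer exposes no such attribute, since bf16 training does not use dynamic loss scaling. A purely illustrative guard, not necessarily what the linked fix does, could look like the hypothetical helper below.)

def optimizer_overflow(optimizer) -> bool:
    # Hypothetical helper: return the overflow flag if the optimizer tracks one,
    # otherwise report no overflow (e.g. DeepSpeed's BF16_Optimizer, which does
    # not use dynamic loss scaling and therefore never sets `overflow`).
    return getattr(optimizer, "overflow", False)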

@panda1681

Encountered the same problem.
