
[BUG] Loss scale already at minimum - Training LlaMA2 7B via HF+deepspeed consistently fails #4017

Closed
scorixear opened this issue Jul 22, 2023 · 27 comments
Labels: bug, training

scorixear commented Jul 22, 2023

Describe the bug
When training the LLaMA2 7B HF model with DeepSpeed on a single-node, multi-GPU setup,
the loss scale is consistently decreased down to 1 (the minimum) and the run exits with an error.

Exception: Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.

This happens with both ZeRO Stage 2 + CPU offload and ZeRO Stage 3 + CPU offload.

To Reproduce

  • DeepSpeed with ZeRO Stage 2 + CPU offload.
  • HF Trainer (v4.32.0.dev0)

Expected behavior
Training Completes

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/software/all/staging/PyTorch/1.12.1-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/torch']
torch version .................... 1.12.1
deepspeed install path ........... ['/home/<BLANKED>/LLaMA_Training/.env/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.10.0, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.7

System info (please complete the following information):

  • OS: Rocky Linux 8
  • GPU count and types: 1 Machine, 4x Nvidia Tesla V100
  • Python version: 3.10.8
  • HuggingFace Transformers: 4.32.0.dev0

Launcher context
Launching with Deepspeed as follows:

srun --jobid 3530272 bash -c "NCCL_DEBUG=INFO deepspeed 
--num_gpus=4 
03_train_llama2.py 
--model_name meta-llama/Llama-2-7b-hf 
--cache_dir ./cache 
--use_fast_tokenizer false 
--model_revision main 
--use_auth_token true 
--hugging_token <BLANKED>
--torch_dtype auto 
--low_cpu_mem_usage false 
--train_file ./input/health_information_systems_epub.md 
--max_train_samples 1000 
--overwrite_cache false 
--validation_split_percentage 5 
--preprocessing_num_workers 1 
--keep_linebreaks true 
--output_dir ./trained/7B 
--overwrite_output_dir false 
--do_train true 
--do_eval false 
--per_device_train_batch_size 1 
--per_device_eval_batch_size 1 
--evaluation_strategy steps 
--eval_steps 100 
--learning_rate 3e-4 
--weight_decay 0.1 
--adam_beta1 0.9 
--adam_beta2 0.95 
--adam_epsilon 1e-8 
--max_grad_norm 1.0 
--num_train_epochs 3 
--lr_scheduler_type cosine 
--warmup_steps 0 
--log_level passive 
--save_strategy steps 
--save_steps 500 
--save_total_limit 1 
--no_cuda false 
--seed 42 
--fp16 true 
--bf16 false 
--half_precision_backend auto 
--local_rank 0 
--ddp_backend nccl 
--deepspeed ./ds_configs/stage2_offload.json 
--optim adamw_torch"

Docker context
No Docker

Additional context
DS Config:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": { 
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": true,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "allgather_bucket_size": 2e8,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_clipping": 1.0,
    "steps_per_print": 500,
    "wall_clock_breakdown": false,
    "train_micro_batch_size_per_gpu": 1
}
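For context, the exception in this issue comes from DeepSpeed's dynamic loss scaler: every step whose gradients contain inf/NaN halves the scale, and once the scale cannot drop below min_loss_scale the run aborts. A simplified sketch of what the fp16 settings above do (illustrative only, not the actual DeepSpeed implementation):

```
# Simplified sketch of dynamic loss scaling -- illustrative only, not DeepSpeed's code.
class DynamicLossScalerSketch:
    def __init__(self, initial_scale_power=16, loss_scale_window=1000, min_loss_scale=1):
        self.scale = 2.0 ** initial_scale_power  # 65536 with the config above
        self.window = loss_scale_window
        self.min_scale = min_loss_scale
        self.good_steps = 0

    def update(self, overflow: bool):
        if overflow:
            # An inf/NaN gradient was seen: skip the step and halve the scale.
            if self.scale <= self.min_scale:
                raise Exception("Current loss scale already at minimum - "
                                "cannot decrease scale anymore. Exiting run.")
            self.scale = max(self.scale / 2.0, self.min_scale)
            self.good_steps = 0
        else:
            # After `loss_scale_window` clean steps in a row, raise the scale again.
            self.good_steps += 1
            if self.good_steps % self.window == 0:
                self.scale *= 2.0
```

So a long run of overflowing steps drives the scale from 65536 down to 1 and triggers the error above.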

Slurm Setup

#SBATCH --job-name=deepspeed-llama2-7b-hf        # name
#SBATCH --nodes=1                                                 # nodes
#SBATCH --ntasks-per-node=1                                 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=4
#SBATCH --partition=clara
#SBATCH --mem=256G                                            # 128G was not enough
#SBATCH --gres=gpu:v100:4                                    # number of gpus
#SBATCH --output=logs/%x-%j.out                           # output file name
scorixear added the bug and training labels on Jul 22, 2023
scorixear (Author) commented Jul 23, 2023

Reducing the block size from 1024 to 256 partially solves this issue:
if the dataset used is small enough, the loss scale does not get reduced all the way down to 1.

But the loss scale is still constantly being reduced.

fahadh4ilyas commented

Reducing the bucket size from 1024 to 256 partially solves this issue: if the dataset used is small enough, the loss scale does not get reduced all the way down to 1.

But the loss scale is still constantly being reduced.

What do you mean by bucket size? If you mean reduce_bucket_size, the value is not 1024 but 2e8 (200 million).

wxjiao commented Jul 28, 2023

Try to use bf16 as LLaMA-2 was pretrained using bf16. Continuing the training with fp16 will be problematic.
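Roughly, the change on the HF + DeepSpeed side looks like this (a sketch; only the relevant fields are shown, everything else can stay as in the config posted above):

```
from transformers import TrainingArguments

# Sketch: swap the "fp16" block of the DeepSpeed config for a "bf16" block and
# flip the Trainer flags; the remaining config stays unchanged.
ds_config = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "train_micro_batch_size_per_gpu": 1,
}

args = TrainingArguments(
    output_dir="./trained/7B",
    bf16=True,            # instead of fp16=True
    fp16=False,
    deepspeed=ds_config,  # the Trainer accepts a dict or a path to a JSON file
)
```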

fahadh4ilyas commented

Try to use bf16 as LLaMA-2 was pretrained using bf16. Continuing the training with fp16 will be problematic.

What if my GPU does not support bf16? LLaMA-2 13B did not have this error when training with DeepSpeed; only 7B kept getting it.

FrankWhh commented

How many GPUs do you use for training? With 1 or 2 it's fine, while 3 or more hit this issue. V100, fp16.

fahadh4ilyas commented

How many GPUs do you use for training? With 1 or 2 it's fine, while 3 or more hit this issue. V100, fp16.

I use 8 x A6000 GPUs. A small number of GPUs works? That's weird. I thought a bigger number leads to a bigger batch, which should mean better stability.

FrankWhh commented

You can try it; I saw this in my own training but don't know why.

scorixear (Author) commented Jul 29, 2023

@fahadh4ilyas

Reducing the bucket size from 1024 to 256 partially solves this issue: if the dataset used is small enough, the loss scale does not get reduced all the way down to 1.
But the loss scale is still constantly being reduced.

What do you mean by bucket size? If you mean reduce_bucket_size, the value is not 1024 but 2e8 (200 million).

I mean the actual size of the blocks the text is chunked into (the sequence length). When using the Huggingface Trainer example code you can specify "block_size" for this (it is part of the DataTrainingArguments, e.g. in https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py). In the Llama 2 paper they used 4096; the default is normally 1024. A 3-epoch training run worked with 256;
for more epochs I have to go even lower.

@wxjiao

Try to use bf16 as LLaMA-2 was pretrained using bf16. Continuing the training with fp16 will be problematic.

Sadly that is not possible, as I am using Nvidia V100 GPUs which do not support bf16, so I have to use fp16.

@FrankWhh

How many GPUs do you use for training? With 1 or 2 it's fine, while 3 or more hit this issue. V100, fp16.

Interesting. And you did not run into any OOMs when using ZeRO 2 with 2 GPUs?

FrankWhh commented

32 GB is enough for a LoRA-style finetune.

fahadh4ilyas commented

@fahadh4ilyas

Reducing the bucket size from 1024 to 256 partially solves this issue: if the dataset used is small enough, the loss scale does not get reduced all the way down to 1.
But the loss scale is still constantly being reduced.

What do you mean by bucket size? If you mean reduce_bucket_size, the value is not 1024 but 2e8 (200 million).

I mean the actual size of the batch in bytes. When using the Huggingface Trainer you can specify "bucket_size" there. In the Llama 2 paper they used 4096; the default is normally 1024. A 3-epoch training run worked with 256;
for more epochs I have to go even lower.

Do you mean the ddp_bucket_cap_mb parameter? It seems its default value is None. How do you know the default is 1024? And does it affect performance?

scorixear (Author) commented

@fahadh4ilyas

Reducing the bucket size from 1024 to 256 partially solves this issue: if the dataset used is small enough, the loss scale does not get reduced all the way down to 1.
But the loss scale is still constantly being reduced.

What do you mean by bucket size? If you mean reduce_bucket_size, the value is not 1024 but 2e8 (200 million).

I mean the actual size of the batch in bytes. When using the Huggingface Trainer you can specify "bucket_size" there. In the Llama 2 paper they used 4096; the default is normally 1024. A 3-epoch training run worked with 256;
for more epochs I have to go even lower.

Do you mean the ddp_bucket_cap_mb parameter? It seems its default value is None. How do you know the default is 1024? And does it affect performance?

Sorry I meant "block_size". I will edit my message to that now.

fahadh4ilyas commented

@fahadh4ilyas

Reducing the bucket size from 1024 to 256 partially solves this issue: if the dataset used is small enough, the loss scale does not get reduced all the way down to 1.
But the loss scale is still constantly being reduced.

What do you mean by bucket size? If you mean reduce_bucket_size, the value is not 1024 but 2e8 (200 million).

I mean the actual size of the batch in bytes. When using the Huggingface Trainer you can specify "bucket_size" there. In the Llama 2 paper they used 4096; the default is normally 1024. A 3-epoch training run worked with 256;
for more epochs I have to go even lower.

Do you mean the ddp_bucket_cap_mb parameter? It seems its default value is None. How do you know the default is 1024? And does it affect performance?

Sorry I meant "block_size". I will edit my message to that now.

Where exactly is that param? I searched TrainingArguments but it has no "block_size".

scorixear (Author) commented

Sorry I meant "block_size". I will edit my message to that now.

Where exactly is that param? I searched TrainingArguments but it has no "block_size".

It is not part of the TrainingArguments.
You usually have to implement it yourself - I used the code from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py

It groups the text into chunks of the specified size.
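The relevant part of run_clm.py looks roughly like this (a trimmed sketch; the real script also caps block_size at the tokenizer's model_max_length):

```
# Trimmed sketch of the grouping step used in run_clm.py: concatenate the
# tokenized text, then cut it into fixed blocks of `block_size` tokens.
block_size = 256  # the value that worked for 3 epochs in my runs

def group_texts(examples):
    # Concatenate every field (input_ids, attention_mask, ...) across the batch.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the tail so every block has exactly `block_size` tokens.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Applied to the tokenized dataset with datasets.map(group_texts, batched=True).
```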

fahadh4ilyas commented

Sorry I meant "block_size". I will edit my message to that now.

Where exactly is that param? I searched TrainingArguments but it has no "block_size".

It is not part of the TrainingArguments.
You usually have to implement it yourself - I used the code from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py

It groups the text into chunks of the specified size.

Oh, that... I can't change that. My data are longer than 256 tokens; trimming them would only make the training results worse.

fahadh4ilyas commented

You can try it; I saw this in my own training but don't know why.

Training with only 2 GPUs also did not solve the problem for me.

lucasjinreal commented

Same issue. What's weird is that it happens with more than 98% probability on V100, but...

wxjiao commented Aug 2, 2023

You may be interested in this HF doc: https://huggingface.co/docs/transformers/v4.15.0/performance

[screenshot from the linked doc]

FrankWhh commented Aug 2, 2023

Yeah, bf16 is the preferred choice, but V100 doesn't support it.

lucasjinreal commented

Yes, with bf16 this issue wouldn't bother us.

RameshArvind commented Aug 9, 2023

Try to use bf16 as LLaMA-2 was pretrained using bf16. Continuing the training with fp16 will be problematic.

@wxjiao do you have some kind of official reference that confirms this? I'm unable to find anything apart from the fact that the config.json for Llama-2-7b says the dtype is float16.

chenfengshijie commented Aug 17, 2023

I encountered the same problem when training a ViT. I set loss_scale_window to a relatively small value (e.g. 100), so the loss scale gets the opportunity to rise again after it has been decreased on some batches. That solved the problem; you may need to choose an appropriate loss_scale_window. I am also using V100.
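In the DeepSpeed config that corresponds to lowering loss_scale_window in the fp16 section, roughly like this (a sketch; the other values are the ones already used in this thread):

```
# Sketch: shrink the loss-scale window so the scale can climb back up sooner
# after it has been halved on an overflowing batch.
fp16_section = {
    "enabled": True,
    "loss_scale": 0,            # 0 = dynamic loss scaling
    "loss_scale_window": 100,   # down from the 1000 used in the config above
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1,
}
```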

scorixear (Author) commented

Try to use bf16 as LLaMA-2 was pretrained using bf16. Continuing the training with fp16 will be problematic.

You may be interested in this HF doc: https://huggingface.co/docs/transformers/v4.15.0/performance
[screenshot from the linked doc]

Yeah, bf16 is the preferred choice, but V100 doesn't support it.

Yes, with bf16 this issue wouldn't bother us.

I can confirm that running the same script with an identical configuration, but on Nvidia Tesla A30 GPUs with bf16 enabled, solves this issue.
There are no overflows occurring anymore.
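For anyone unsure whether their GPUs can take this route, a quick check (sketch, recent PyTorch):

```
import torch

# bf16 needs compute capability >= 8.0 (Ampere and newer, e.g. A30/A100);
# a V100 (7.0) will print False here.
print(torch.cuda.get_device_name(0))
print(torch.cuda.is_bf16_supported())
```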

YuFan-Microsoft commented

Any updates here? I also face this problem with my 8*V100 machines.

scorixear (Author) commented

Any updates here? I also face this problem with my 8*V100 machines.

None so far. V100 leads to overflow and a huge loss in performance, as per my latest evaluation.
For now I think there are three options:

  • don't use V100 if a model with bf16 and DeepSpeed is required
  • don't use DeepSpeed if a model with bf16 and V100 is required
  • don't use a bf16 model if V100 and DeepSpeed are required

The pull request is also still open and has had no recent changes.
@YuFan-Microsoft

wxjiao commented Oct 13, 2023

@scorixear

Other options could be:

Both work, according to my friends' practice.

loadams (Contributor) commented Jan 18, 2024

@scorixear - can you take a look and see if the changes in #4141 help?

scorixear (Author) commented

@scorixear - can you take a look and see if the changes in #4141 help?

@loadams I have rerun my setup with the latest version of DeepSpeed (0.13.0) and noticed degraded model performance due to loss-scale overflows that are still present.

However, with the previous version I used (0.9.x) I couldn't train the Llama2 model beyond approximately 400 steps without running into the minimum-loss-scale issue explained above.

This does not occur anymore; the loss scale seems to be stable around 64, so this update definitely did something.
Degraded performance of models previously trained in bf16 and now trained in fp16 is somewhat expected, as pointed out in the comments here and in the Transformers documentation from Huggingface.

It is unfortunate that models are therefore hardware dependent, but I don't think this is an issue related to DeepSpeed.

I would leave it at that. Maybe you can decide whether this issue should remain open (as it is still an issue in general) or whether it can be closed (as it is not really related to DeepSpeed anymore).

mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this issue Feb 17, 2024
…ia `load_module_only` (microsoft#4141)

This PR makes some fixes to the case where we want to resume training
from a DeepSpeed ZeRO checkpoint and initialize a new optimizer, while
not using the old optimizer in the checkpoint or relying on its
existence at all.

In this situation, despite passing `load_module_only=True` and
`load_optimizer_states=False` to `load_checkpoint()`, the previous
behavior was that:
- `self._load_zero_checkpoint` would still be called, which attempts to
load from the (in this case, nonexistent) checkpoint files. This PR
stops this function from being called if using `load_module_only=True`
and `load_optimizer_states=False`. Alternatively, calling this function
may be alright if `"load_from_fp32_weights": true` is set in the
DeepSpeed ZeRO config (reference:
https://github.com/microsoft/DeepSpeed/blob/ff7d5275f2aa916cb5f320e0d817154e96f9cdb6/deepspeed/runtime/engine.py#L733)
but this parameter does not seem to be documented in the docs for ZeRO
config dicts.
- In `_load_checkpoint`, the following code block:
```
if self.optimizer is not None and self.fp16_enabled():
    self.optimizer.refresh_fp32_params()
```
results in `self.optimizer.refresh_fp32_params()` being called only if
using FP16. As a result, the FP32 optimizer state is never initialized
from the 16-bit model weights. This PR removes the fp16-specific
condition.
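For illustration, the resume path this targets looks roughly like this (a sketch; assumes `engine` is an already-initialized DeepSpeed engine):

```
# Sketch: load only the module weights from a ZeRO checkpoint and start a
# fresh optimizer, without touching the (possibly nonexistent) optimizer states.
load_path, client_state = engine.load_checkpoint(
    "path/to/checkpoint_dir",
    load_module_only=True,
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
)
```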


Previously reported in:
EleutherAI/gpt-neox#947
EleutherAI/gpt-neox#843

Should also close:
microsoft#4017

Fixes: microsoft#4944 and microsoft#4017

This caused problems for a freshly-converted Llama checkpoint, which did
not contain optimizer states, when trying to train with this model as
initialization. I have confirmed the following fixes prevent this
behavior.

cc @Quentin-Anthony @zhangir-azerbayev

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>