
Fine-tuning loss explodes when not loading DeepSpeed ZeRO optimizer states #843

Open
sxthunder opened this issue Mar 19, 2023 · 9 comments
Labels: bug (Something isn't working)

Comments

sxthunder commented Mar 19, 2023

Describe the bug
I trained a 1.3B model on 64 A100 80GB GPUs and exported the saved checkpoints without the DeepSpeed ZeRO optimizer states; the exported checkpoint structure is the same as that of your open-source 20B checkpoints.
Then I want to fine-tune the model on 8 GPUs, adding only {"finetune": true} to the config YAML.
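For readers unfamiliar with GPT-NeoX configs, a minimal sketch of the fine-tuning fragment described above is shown below. This is illustrative only: `finetune` and `load` are standard GPT-NeoX config keys, the `load` path is taken from the logs further down, and the `save` path is a placeholder.

```yaml
# Illustrative fragment, merged on top of the original training config.
# "finetune": true makes GPT-NeoX skip restoring optimizer/LR-scheduler state
# and reset the iteration counter; "load" points at the exported checkpoints.
# The "save" path below is a placeholder.
{
  "finetune": true,
  "load": "/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b",
  "save": "/path/to/finetune-output"
}
```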

When I run the program, the model parameters are loaded successfully:

[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 5 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 7 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 4 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 0 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 1 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 2 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,600] [INFO] [engine.py:1551:_load_checkpoint] rank: 3 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:46,601] [INFO] [engine.py:1551:_load_checkpoint] rank: 6 loading checkpoint: /mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/mp_rank_00_model_states.pt
[2023-03-19 12:18:47,239] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=0 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_00-model_00-model_states.pt
[2023-03-19 12:18:47,419] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=2 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_02-model_00-model_states.pt
[2023-03-19 12:18:47,570] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=3 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_03-model_00-model_states.pt
[2023-03-19 12:18:47,761] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=4 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_04-model_00-model_states.pt
[2023-03-19 12:18:47,942] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=5 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_05-model_00-model_states.pt
[2023-03-19 12:18:48,135] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=6 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_06-model_00-model_states.pt
[2023-03-19 12:18:48,317] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=7 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_07-model_00-model_states.pt
[2023-03-19 12:18:48,496] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=8 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_08-model_00-model_states.pt
[2023-03-19 12:18:48,667] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=9 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_09-model_00-model_states.pt
[2023-03-19 12:18:48,867] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=10 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_10-model_00-model_states.pt
[2023-03-19 12:18:49,057] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=11 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_11-model_00-model_states.pt
[2023-03-19 12:18:49,249] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=12 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_12-model_00-model_states.pt
[2023-03-19 12:18:49,415] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=13 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_13-model_00-model_states.pt
[2023-03-19 12:18:49,621] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=14 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_14-model_00-model_states.pt
[2023-03-19 12:18:49,795] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=15 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_15-model_00-model_states.pt
[2023-03-19 12:18:49,972] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=16 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_16-model_00-model_states.pt
[2023-03-19 12:18:50,197] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=17 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_17-model_00-model_states.pt
[2023-03-19 12:18:50,384] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=18 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_18-model_00-model_states.pt
[2023-03-19 12:18:50,571] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=19 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_19-model_00-model_states.pt
[2023-03-19 12:18:50,737] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=20 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_20-model_00-model_states.pt
[2023-03-19 12:18:50,922] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=21 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_21-model_00-model_states.pt
[2023-03-19 12:18:51,082] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=22 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_22-model_00-model_states.pt
[2023-03-19 12:18:51,268] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=23 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_23-model_00-model_states.pt
[2023-03-19 12:18:51,549] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=24 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_24-model_00-model_states.pt
[2023-03-19 12:18:51,708] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=25 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_25-model_00-model_states.pt
[2023-03-19 12:18:51,713] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=27 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_27-model_00-model_states.pt
[2023-03-19 12:18:52,336] [INFO] [module.py:576:load_state_dir] RANK=0 Loaded layer=28 file=/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/layer_28-model_00-model_states.pt

But after that, it also tries to load the ZeRO optimizer states, which are obviously missing:

[2023-03-19 12:18:52,428] [WARNING] [engine.py:1656:_get_all_zero_checkpoints] The following zero checkpoints paths are missing: ['/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_0_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_1_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_2_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_3_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_4_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_5_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_6_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_7_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_8_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_9_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_10_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_11_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_12_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_13_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_14_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_15_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_16_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_17_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_18_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_19_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_20_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_21_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_22_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_23_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_24_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_25_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_26_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_27_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_28_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_29_mp_rank_00_optim_states.pt', 
'/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_30_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_31_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_32_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_33_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_34_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_35_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_36_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_37_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_38_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_39_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_40_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_41_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_42_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_43_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_44_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_45_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_46_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_47_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_48_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_49_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_50_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_51_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_52_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_53_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_54_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_55_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_56_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_57_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_58_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_59_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_60_mp_rank_00_optim_states.pt', 
'/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_61_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_62_mp_rank_00_optim_states.pt', '/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b/global_step80000/zero_pp_rank_63_mp_rank_00_optim_states.pt']

Then the model starts training, but the loss scale is abnormal; you can see that the first 10 steps are skipped:

[2023-03-19 12:19:18,292] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,292] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,292] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,292] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,292] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,293] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,293] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:18,293] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 4294967296
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:23,759] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:28,911] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:34,127] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:39,412] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 536870912.0, reducing to 268435456.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:44,737] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 268435456.0, reducing to 134217728.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:50,055] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 134217728.0, reducing to 67108864.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:19:55,372] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 67108864.0, reducing to 33554432.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:00,688] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 33554432.0, reducing to 16777216.0
[2023-03-19 12:20:06,011] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
[2023-03-19 12:20:06,012] [INFO] [stage1.py:697:step] [deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss scale: 16777216.0, reducing to 8388608.0
 samples/sec: 37.204 | iteration       10/   10000 | elapsed time per iteration (ms): 6881.1 | learning rate: 0.000E+00 | approx flops per GPU: 150.1TFLOPS | loss scale: 8388608.0 | number of skipped iterations:  10 | number of nan iterations:   0 |
after 10 iterations memory (MB) | allocated: 11998.7001953125 | max allocated: 26177.17822265625 | reserved: 32944.0 | max reserved: 32944.0
time (ms)
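For context, the skipped steps above come from DeepSpeed's dynamic fp16 loss scaling: every overflow skips the optimizer step and reduces the scale, so a scale that starts at 2^32 needs many skipped iterations before real updates begin. A minimal sketch of the mechanism (illustrative only, not DeepSpeed's actual implementation):

```python
# Minimal sketch of dynamic loss scaling (illustrative, not DeepSpeed's code).
class DynamicLossScaler:
    def __init__(self, init_scale=2**32, scale_factor=2.0, scale_window=1000):
        self.scale = float(init_scale)    # loss is multiplied by this before backward()
        self.scale_factor = scale_factor
        self.scale_window = scale_window  # grow the scale after this many clean steps
        self.good_steps = 0

    def update(self, found_overflow: bool) -> bool:
        """Return True if the optimizer step should run, False if it is skipped."""
        if found_overflow:
            # Skip the step and shrink the scale, as in the log lines above.
            self.scale = max(self.scale / self.scale_factor, 1.0)
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.scale_window == 0:
            self.scale *= self.scale_factor
        return True
```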

For steps 10-20 the loss is around 6; after step 20 the loss jumps to 10+:

 samples/sec: 47.912 | iteration       20/   10000 | elapsed time per iteration (ms): 5343.1 | learning rate: 2.771E-08 | approx flops per GPU: 193.3TFLOPS | lm_loss: 6.614222E+00 | loss scale: 32768.0 | number of skipped iterations:   8 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.385 | iteration       30/   10000 | elapsed time per iteration (ms): 5402.5 | learning rate: 1.663E-07 | approx flops per GPU: 191.2TFLOPS | lm_loss: 1.162083E+01 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.342 | iteration       40/   10000 | elapsed time per iteration (ms): 5407.4 | learning rate: 3.049E-07 | approx flops per GPU: 191.0TFLOPS | lm_loss: 1.084407E+01 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.399 | iteration       50/   10000 | elapsed time per iteration (ms): 5400.9 | learning rate: 4.434E-07 | approx flops per GPU: 191.2TFLOPS | lm_loss: 1.001282E+01 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.404 | iteration       60/   10000 | elapsed time per iteration (ms): 5400.4 | learning rate: 5.820E-07 | approx flops per GPU: 191.3TFLOPS | lm_loss: 9.492722E+00 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.411 | iteration       70/   10000 | elapsed time per iteration (ms): 5399.6 | learning rate: 7.206E-07 | approx flops per GPU: 191.3TFLOPS | lm_loss: 9.127460E+00 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.363 | iteration       80/   10000 | elapsed time per iteration (ms): 5405.0 | learning rate: 8.591E-07 | approx flops per GPU: 191.1TFLOPS | lm_loss: 8.873161E+00 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 47.396 | iteration       90/   10000 | elapsed time per iteration (ms): 5401.3 | learning rate: 9.700E-07 | approx flops per GPU: 191.2TFLOPS | lm_loss: 8.657443E+00 | loss scale: 32768.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

After that, the loss decreases as it does in pre-training. I tested the fine-tuned model, and it is clearly abnormal.

Then I re-ran the fine-tuning on 64 GPUs with the DeepSpeed ZeRO optimizer states, and everything went well:

 samples/sec: 95.657 | iteration       10/     200 | elapsed time per iteration (ms): 5352.5 | learning rate: 9.685E-07 | approx flops per GPU: 48.2TFLOPS | lm_loss: 1.250598E+00 | loss scale: 4096.0 | number of skipped iterations:   3 | number of nan iterations:   0 |
after 10 iterations memory (MB) | allocated: 7413.9873046875 | max allocated: 18502.23876953125 | reserved: 21554.0 | max reserved: 21554.0
time (ms)
 samples/sec: 99.478 | iteration       20/     200 | elapsed time per iteration (ms): 5146.9 | learning rate: 9.566E-07 | approx flops per GPU: 50.2TFLOPS | lm_loss: 1.154644E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 100.073 | iteration       30/     200 | elapsed time per iteration (ms): 5116.3 | learning rate: 9.331E-07 | approx flops per GPU: 50.5TFLOPS | lm_loss: 1.105953E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 99.810 | iteration       40/     200 | elapsed time per iteration (ms): 5129.8 | learning rate: 8.985E-07 | approx flops per GPU: 50.3TFLOPS | lm_loss: 1.073559E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 98.904 | iteration       50/     200 | elapsed time per iteration (ms): 5176.7 | learning rate: 8.538E-07 | approx flops per GPU: 49.9TFLOPS | lm_loss: 1.053401E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 98.668 | iteration       60/     200 | elapsed time per iteration (ms): 5189.1 | learning rate: 8.000E-07 | approx flops per GPU: 49.8TFLOPS | lm_loss: 1.043200E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 99.999 | iteration       70/     200 | elapsed time per iteration (ms): 5120.0 | learning rate: 7.384E-07 | approx flops per GPU: 50.4TFLOPS | lm_loss: 1.031303E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 98.737 | iteration       80/     200 | elapsed time per iteration (ms): 5185.5 | learning rate: 6.706E-07 | approx flops per GPU: 49.8TFLOPS | lm_loss: 1.022049E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
time (ms)
 samples/sec: 98.484 | iteration       90/     200 | elapsed time per iteration (ms): 5198.8 | learning rate: 5.982E-07 | approx flops per GPU: 49.7TFLOPS | lm_loss: 1.015269E+00 | loss scale: 4096.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Is this a bug, or did I miss some processing step on the pretraining checkpoints?

sxthunder added the bug (Something isn't working) label on Mar 19, 2023
@sxthunder
Author

@StellaAthena

sxthunder changed the title from "Finetuning loss explode when deepspeed zero optimal states" to "Finetuning loss explode when not loading deepspeed zero optimal states" on Mar 20, 2023
@guozhiyao

Hi, have you solved this problem? I am running into the same issue.

@sxthunder
Author

Hi, have you solved this problem? I am running into the same issue.

If you are using Megatron, it seems you must load the ZeRO optimizer states. Keep the MP size the same; the number of GPUs can differ.
Alternatively, convert the model to Hugging Face format, but then the training speed is lower.

@guozhiyao

Hi, have you solved this problem? I am running into the same issue.

If you are using Megatron, it seems you must load the ZeRO optimizer states. Keep the MP size the same; the number of GPUs can differ. Alternatively, convert the model to Hugging Face format, but then the training speed is lower.

I kept the MP size the same while changing the number of GPUs, but the loss still climbed to 10+, even when I set load_optimizer_states=False, load_lr_scheduler_states=False, and load_module_only=True.
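For reference, those three flags map onto DeepSpeed's `DeepSpeedEngine.load_checkpoint()` arguments. A rough sketch of the call, assuming `engine` is an already-initialized DeepSpeed engine and reusing the checkpoint path and tag from the logs earlier in this issue:

```python
# Load only the module weights from a checkpoint that has no ZeRO optimizer
# shards, and continue training with a freshly initialized optimizer.
load_path, client_state = engine.load_checkpoint(
    "/mnt/resources/codegpt-sft-v1/checkpoint-neox-1.3b",
    tag="global_step80000",
    load_module_only=True,
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
)
```

The DeepSpeed commit referenced at the end of this thread later changed how exactly this flag combination is handled.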

@StellaAthena
Member

Can you explain why this isn’t the desired behavior?

@sxthunder
Author

I converted the parameters to Hugging Face format without the DeepSpeed ZeRO states, and it works well. Why must GPT-NeoX load the ZeRO states?

@sxthunder
Author

Can you explain why this isn’t the desired behavior?

Does GPT-NeoX 2.0 not support fine-tuning a model with a different number of GPUs? I pretrained a 6B model with GPT-NeoX 2.0 on 256 GPUs, then fine-tuned it on 32 GPUs. The logs show that the model states and the ZeRO optimizer are loaded successfully, but the loss explodes after the second step.

@StellaAthena
Member

Can you explain why this isn’t the desired behavior?

Does GPT-NeoX 2.0 not support fine-tuning a model with a different number of GPUs? I pretrained a 6B model with GPT-NeoX 2.0 on 256 GPUs, then fine-tuned it on 32 GPUs. The logs show that the model states and the ZeRO optimizer are loaded successfully, but the loss explodes after the second step.

We currently have a PR working its way through that will fix this problem. We hope to have it merged later this week. #836

@sxthunder
Author

Can you explain why this isn’t the desired behavior?

Does GPT-NeoX 2.0 not support fine-tuning a model with a different number of GPUs? I pretrained a 6B model with GPT-NeoX 2.0 on 256 GPUs, then fine-tuned it on 32 GPUs. The logs show that the model states and the ZeRO optimizer are loaded successfully, but the loss explodes after the second step.

@StellaAthena Thanks, we look forward to being among the first users.

github-merge-queue bot pushed a commit to microsoft/DeepSpeed that referenced this issue Jan 18, 2024
…ia `load_module_only` (#4141)

This PR makes some fixes to the case where we want to resume training
from a DeepSpeed ZeRO checkpoint and initialize a new optimizer, while
not using the old optimizer in the checkpoint or relying on its
existence at all.

In this situation, despite passing `load_module_only=True` and
`load_optimizer_states=False` to `load_checkpoint()`, the previous
behavior was that:
- `self._load_zero_checkpoint` would still be called, which attempts to
load from the (in this case, nonexistent) checkpoint files. This PR
stops this function from being called if using `load_module_only=True`
and `load_optimizer_states=False`. Alternatively, calling this function
may be alright if `"load_from_fp32_weights": true` is set in the
DeepSpeed ZeRO config (reference:
https://github.com/microsoft/DeepSpeed/blob/ff7d5275f2aa916cb5f320e0d817154e96f9cdb6/deepspeed/runtime/engine.py#L733)
but this parameter does not seem to be documented in the docs for ZeRO
config dicts.
- in `_load_checkpoint`, the following code block:
```
if self.optimizer is not None and self.fp16_enabled():
    self.optimizer.refresh_fp32_params()
```
results in `self.optimizer.refresh_fp32_params()` being called only if
using FP16. As a result, the FP32 optimizer state is never initialized
from the 16-bit model weights. This PR removes the fp16-specific
condition.


Previously reported in:
EleutherAI/gpt-neox#947
EleutherAI/gpt-neox#843

Should also close:
#4017

Fixes: #4944 and #4017

This caused problems for a freshly-converted LLama checkpoint, which did
not contain optimizer states, when trying to train with this model as
initialization. I have confirmed the following fixes prevent this
behavior.

cc @Quentin-Anthony @zhangir-azerbayev

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
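To make the two bullet points in the commit message above concrete, here is a schematic rendering of the fixed behavior. This is illustrative only: the real DeepSpeed code is organized differently, and `finish_checkpoint_load` is a made-up helper name.

```python
# Schematic rendering of the two fixes described above (not DeepSpeed's real code).
def finish_checkpoint_load(engine, load_dir, tag,
                           load_module_only=False,
                           load_optimizer_states=True):
    # Fix 1: only read the zero_pp_rank_*_optim_states.pt shards when the old
    # optimizer state is actually wanted; otherwise missing shards are harmless.
    if engine.zero_optimization() and not (load_module_only
                                           and not load_optimizer_states):
        engine._load_zero_checkpoint(load_dir, tag,
                                     load_optimizer_states=load_optimizer_states)

    # Fix 2: rebuild the fp32 master weights from the freshly loaded 16-bit
    # module weights regardless of precision mode (previously gated on fp16).
    if load_module_only and engine.optimizer is not None:
        engine.optimizer.refresh_fp32_params()
```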
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this issue Feb 17, 2024
…ia `load_module_only` (microsoft#4141)
