
misindexing when converting llama weights to gpt-neox format #971

Closed
CRSilkworth opened this issue Jun 9, 2023 · 13 comments

Labels
bug Something isn't working

Comments

@CRSilkworth

CRSilkworth commented Jun 9, 2023

Describe the bug
After running convert_raw_llama_weights_to_neox.py with --pipeline_parallel, the checkpoint is missing the files for the 2nd and 3rd layers, i.e.:
layer_02-model_*-model_states.pt
layer_03-model_*-model_states.pt

The first layer files after the layer_00-model_* files are the layer_04-model_* files. But other gpt-neox checkpoints do have the layer_02 and layer_03 files, which is what GPTModelPipe is expecting.

This causes an error when loading the model for training / inference, since those weights are not found.
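
As a quick sanity check, this small throwaway snippet (my own, not from the repo; the path is a placeholder) lists which layer indices actually exist in the converted checkpoint directory:

import re
from pathlib import Path

ckpt_dir = Path("/path/to/output_checkpoints")  # placeholder: the converted checkpoint dir
indices = sorted({int(m.group(1))
                  for p in ckpt_dir.rglob("layer_*-model_*-model_states.pt")
                  if (m := re.match(r"layer_(\d+)-", p.name))})
print(indices)  # a correct conversion includes 2 and 3; the buggy one jumps from 0 to 4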

To Reproduce

  1. Run convert_raw_llama_weights_to_neox.py with --pipeline_parallel:
python tools/convert_raw_llama_weights_to_neox.py --input_dir </path/to/py_lamma_data> --model_size 7B --output_dir </path/to/output_checkpoints> --num_output_shards <mp> --pipeline_parallel
  2. Run finetuning or generate text with the load attribute pointing to the newly converted checkpoints:
python ./deepy.py train.py configs/llama/7B.yml configs/cluster_config.yml
  3. You get this error:
Traceback (most recent call last):
  File "/home/mchorse/generate.py", line 91, in <module>
    main()
  File "/home/mchorse/generate.py", line 33, in main
    model, neox_args = setup_for_inference_or_eval(use_cache=True)
  File "/home/mchorse/megatron/utils.py", line 443, in setup_for_inference_or_eval
    model, _, _ = setup_model_and_optimizer(
  File "/home/mchorse/megatron/training.py", line 649, in setup_model_and_optimizer
    neox_args.iteration = load_checkpoint(
  File "/home/mchorse/megatron/checkpointing.py", line 239, in load_checkpoint
    checkpoint_name, state_dict = model.load_checkpoint(
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2599, in load_checkpoint
    load_path, client_states = self._load_checkpoint(load_dir,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2662, in _load_checkpoint
    self.load_module_state_dict(checkpoint=checkpoint,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1274, in load_module_state_dict
    self.module.load_state_dir(load_dir=self._curr_ckpt_path,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/module.py", line 596, in load_state_dir
    sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/state_dict_factory.py", line 43, in get_sd_loader
    return MegatronSDLoader(ckpt_list, version, checkpoint_engine)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/state_dict_factory.py", line 193, in __init__
    super().__init__(ckpt_list, version, checkpoint_engine)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/state_dict_factory.py", line 55, in __init__
    self.check_ckpt_list()
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/state_dict_factory.py", line 168, in check_ckpt_list
    assert len(self.ckpt_list) > 0
AssertionError

I'm pretty sure this occurs because it cannot find the missing layer_02_* and layer_03_* checkpoint files.

Expected behavior
Checkpoints should load successfully.

Proposed solution
I believe the issue was caused by accidentally adding 'layer_i + 2' in two locations instead of one (here and here).

I would just take out the second one, so that the pipeline_parallel version more closely matches the sequential version.
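
As a rough illustration of the indexing (my own sketch, not the script's code; layer_file is a hypothetical helper), the +2 offset should be applied exactly once when building the per-layer filename:

# Hedged sketch: transformer layer 0 should land in layer_02-model_*-model_states.pt,
# presumably because GPTModelPipe reserves the first indices for the layers before
# the first transformer block.
def layer_file(layer_i: int, rank: int) -> str:
    return f"layer_{layer_i + 2:02d}-model_{rank:02d}-model_states.pt"

print(layer_file(0, 0))  # layer_02-model_00-model_states.pt
# Applying the offset twice (passing layer_i + 2 into a save_path that adds 2 again)
# shifts everything to layer_04 onwards, which matches the gap described above.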

Environment (please complete the following information):

  • Just running the convert_raw_llama_weights_to_neox.py script on CPUs.
  • Configs:
    my 7B llama config
{
  "pipe_parallel_size": 4,
  "model_parallel_size": 4,
  "make_vocab_size_divisible_by": 1,
  "deepspeed_mpi": True,
  "launcher": "openmpi",
  "finetune": true,

  # model settings
  "num_layers": 32,
  "hidden_size": 4096,
  "num_attention_heads": 32,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "pos_emb": "rotary",
  "rotary_pct": 1,
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-6,
  

  "scaled_upper_triang_masked_softmax_fusion": true,
  "bias_gelu_fusion": false,
  "use_bias_in_norms": false,
  "use_bias_in_attn_linear": false,
  "mlp_type": "llama",
  "activation": "silu",

  # init methods
   "init_method": "small_init",
   "output_layer_init_method": "wang_init",

   # optimizer settings
   "optimizer": {
     "type": "Adam",
     "params": {
       "lr": 0.00012,
       "betas": [0.9, 0.95],
       "eps": 1.0e-8,
     }
   },

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
  "stage": 1,
  "allgather_partitions": True,
  "allgather_bucket_size": 500000000,
  "overlap_comm": True,
  "reduce_scatter": True,
  "reduce_bucket_size": 500000000,
  "contiguous_gradients": True,
  },
  "min_lr": 0.000012,

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "data_impl": "mmap",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "fp16": {
    "fp16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "checkpoint_factor": 10000,
  "eval_interval": 1000,
  "eval_iters": 10,

  # logging
  "log_interval": 100,
  "steps_per_print": 10,
  "keep_last_n_checkpoints": 4,
  "wall_clock_breakdown": true,

  "tokenizer_type": "SPMTokenizer"

}
CRSilkworth added the bug label on Jun 9, 2023
@HuangLK
Contributor

HuangLK commented Jun 10, 2023

Try removing the "+2" on this line:

torch.save(obj, self.save_path(layer_i=layer_i + 2, rank=rank))

@StellaAthena
Member

@CRSilkworth can you check if the code on the llama-conversion branch works for you?

@CRSilkworth
Author

@StellaAthena It looks like that solves the original issue, but there is another, somewhat unrelated issue, which I believe is due to a deepspeed update that gets pulled in when installing gpt-neox from scratch. It looks like this line was added in the latest deepspeed, and it assumes a 'module' key in the checkpoint dict.

Traceback (most recent call last):
  File "/home/mchorse/train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "/home/mchorse/megatron/training.py", line 192, in pretrain
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
  File "/home/mchorse/megatron/training.py", line 661, in setup_model_and_optimizer
    neox_args.iteration = load_checkpoint(
  File "/home/mchorse/megatron/checkpointing.py", line 239, in load_checkpoint
    checkpoint_name, state_dict = model.load_checkpoint(
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2599, in load_checkpoint
    load_path, client_states = self._load_checkpoint(load_dir,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2662, in _load_checkpoint
    self.load_module_state_dict(checkpoint=checkpoint,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1271, in load_module_state_dict
    super().load_module_state_dict(state_dict, strict)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 2458, in load_module_state_dict
    module_state_dict = checkpoint['module']
KeyError: 'module'

I can get it to load if I set this line to 'None' instead of an empty dict. Not sure if that's kosher? I suspect there is some kind of recursion going on, although I'm not very familiar with this code so it's a little hard to follow.
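
For illustration, here is a toy version of the failure mode and the workaround (made-up checkpoint contents, not DeepSpeed's actual code):

# A pipeline-parallel checkpoint has no 'module' entry, so indexing it raises KeyError.
checkpoint = {"optimizer": {}, "lr_scheduler": {}}  # made-up contents, no 'module' key

try:
    module_state_dict = checkpoint["module"]  # the access shown in the traceback
except KeyError:
    module_state_dict = None                  # the workaround: treat the missing entry as None
print(module_state_dict)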

@StellaAthena
Member

Probably a question best posed to @Quentin-Anthony

@Quentin-Anthony
Member

I'll take a look.

@haileyschoelkopf
Contributor

@CRSilkworth This error should be fixable by either passing --pipeline_parallel to tools/convert_raw_llama_weights_to_neox.py, or by setting "pipe-parallel-size": 0 in your LLaMA training config. It typically means that your LLaMA module is in the sequential() format, which is only used when the pipeline parallel size is set to 0 in the most up-to-date version of the code.

I can make a PR to make --pipeline_parallel on by default!

@CRSilkworth
Author

@haileyschoelkopf Actually, this error occurs when setting --pipeline_parallel for tools/convert_raw_llama_weights_to_neox.py and then running with pipe_parallel_size > 1.

@Quan-Sun

I got the same error. Are there any updates?

@haileyschoelkopf
Contributor

Yes, the most recent version (#1124) of the conversion script should no longer have this error; I have tested both round-trip conversion and training.

@linjiadegou2

(quoting CRSilkworth's earlier comment above, including the KeyError: 'module' traceback and the workaround of setting the empty dict to None)

I also encountered this problem. My deepspeed branch is bf16_zero1, and I noticed some changes in the new version of the code. If you know how to fix it, please let me know. Thank you.

@haileyschoelkopf
Contributor

@linjiadegou2 when running the convert_raw_llama_weights_to_neox.py script, if you do not pass --pipeline_parallel you must set pipe-parallel-size: 0 in your YML neox config, and if you do pass --pipeline_parallel you must set pipe-parallel-size to >= 1.

If pipeline parallel size is set to 0, then the checkpoint save/load format is different and neox tries to load from this "module" key, whereas if pipeline parallel is being used then the weights are saved and loaded from per-layer files.
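
In other words, the conversion flag and the training config have to agree. A small sketch of that rule (my own check, not part of the neox codebase):

# Hypothetical helper: the two valid combinations of conversion flag and config value.
def formats_match(converted_with_pipeline_parallel: bool, pipe_parallel_size: int) -> bool:
    if converted_with_pipeline_parallel:
        return pipe_parallel_size >= 1   # per-layer layer_XX-model_YY-model_states.pt files
    return pipe_parallel_size == 0       # single sequential checkpoint loaded via the 'module' key

assert formats_match(True, 1)        # --pipeline_parallel + pipe-parallel-size >= 1
assert formats_match(False, 0)       # no flag + pipe-parallel-size: 0
assert not formats_match(False, 1)   # mismatch -> errors like KeyError: 'module'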

@linjiadegou2

(quoting haileyschoelkopf's comment above about matching --pipeline_parallel with pipe-parallel-size)

I converted the raw llama2 parameters to a format supported by NeoX using convert_raw_llama_weights_to_neox.py with --pipeline_parallel, and my configuration file has "pipe_parallel_size": 1. The problem still arises.

@haileyschoelkopf
Contributor

Could you open a new issue for this? I'll have to try to replicate this.
