Pythia Checkpoint Loading #4

Closed
kshitijkg opened this issue May 30, 2023 · 1 comment
kshitijkg commented May 30, 2023

We need to load Pythia checkpoints for MAGMA training.
Main issue: mismatch between the weights in the checkpoint and the weights in the MAGMA model.
Sources of mismatch:

  1. Naming change due to the Attention module being re-set to the AdapterWrapper (https://github.com/floatingsnake/gpt-neox/blob/magma/megatron/model/adapter.py#L141), resulting in weight names changing, for example from 2.attention.query_key_value.weight to 2.attention.attn_block.query_key_value.weight (see the sketch below).

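For reference, a minimal sketch of how wrapping the attention module shifts the state_dict key names. The Attention and Block classes here are toy stand-ins, not the real GPT-NeoX modules; only the attn_block attribute name mirrors the actual AdapterWrapper.

```python
import torch.nn as nn

# Illustrative stand-ins for the real GPT-NeoX attention module and the MAGMA
# AdapterWrapper; everything except the "attn_block" attribute name is assumed.
class Attention(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)

class AdapterWrapper(nn.Module):
    def __init__(self, attn_block):
        super().__init__()
        # Storing the wrapped module under the attribute "attn_block" is what
        # inserts the extra segment into every state_dict key.
        self.attn_block = attn_block
        self.adapter = nn.Linear(8, 8)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = Attention()

block = Block()
print(list(block.state_dict()))
# ['attention.query_key_value.weight', 'attention.query_key_value.bias']

block.attention = AdapterWrapper(block.attention)
print(list(block.state_dict()))
# ['attention.attn_block.query_key_value.weight',
#  'attention.attn_block.query_key_value.bias',
#  'attention.adapter.weight', 'attention.adapter.bias']
```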
Proposed solutions:
Without changing names in the Pythia checkpoint:

  1. Add adapters after loading the checkpoint, so the restructuring happens after the weights have already been loaded. Disadvantages: adapter weights will have to be loaded separately, and the code will be duplicated and not clean.
  2. Get the class from the module (https://github.com/floatingsnake/gpt-neox/blob/magma/megatron/model/adapter.py#L129), then inherit from it and override the init and forward functions to include adapters. The structure remains the same, but this does not work since we don't easily have the initialization arguments to recreate the attention module; we only have the initialized object (see the sketch after this list).

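A rough sketch of what option 2 would look like and where it breaks down, reusing the toy Attention class from the sketch above rather than the real module:

```python
import torch.nn as nn

# Toy version of option 2: recover the class of the already-initialized
# attention module and subclass it to add adapters.
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x):
        return self.query_key_value(x)

existing_attention = Attention(hidden_size=8)   # built elsewhere; init args not kept around
AttnCls = type(existing_attention)              # the class itself is easy to recover

class AdapterAttention(AttnCls):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)       # needs the original constructor arguments
        self.adapter = nn.Linear(24, 24)

    def forward(self, x):
        hidden = super().forward(x)
        return hidden + self.adapter(hidden)

# Key names stay flat ('query_key_value.weight', 'adapter.weight'), so the
# checkpoint would load unchanged -- but constructing AdapterAttention(...)
# requires the original init arguments, which the initialized object does not
# carry; that is why this option was rejected.
adapter_attention = AdapterAttention(hidden_size=8)
print(list(adapter_attention.state_dict()))
```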
Changing the names in the Pythia checkpoint:

  1. Rename the weights from attention to attention.attn_block and mlp to mlp.attn_block, store the checkpoint again, and use the new checkpoint (a renaming sketch follows this list).
  2. Override with a custom load fn that does the renaming on the fly (Pythia checkpoint loading #3). This solution will not work in the future when we are using pipeline parallelism: custom_load_fn is not supported with pipeline parallelism.

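A minimal sketch of option 1, assuming a single consolidated .pt file whose weights sit under a "module" entry; the file names and that layout are assumptions, and the real Pythia/GPT-NeoX checkpoints are sharded, so the actual convert script in PR #10 will differ.

```python
import torch

def rename_keys(state_dict):
    """Insert the attn_block segment so keys match the AdapterWrapper layout."""
    renamed = {}
    for key, value in state_dict.items():
        key = key.replace(".attention.", ".attention.attn_block.")
        key = key.replace(".mlp.", ".mlp.attn_block.")
        renamed[key] = value
    return renamed

ckpt = torch.load("pythia_checkpoint.pt", map_location="cpu")  # hypothetical path
ckpt["module"] = rename_keys(ckpt["module"])                   # "module" entry is an assumption
torch.save(ckpt, "pythia_checkpoint_renamed.pt")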
Mismatch Source 2:
2. Additional weights in MAGMA, due to the image prefix and adapters.
Proposed solution: can be resolved by setting strict=False when loading the checkpoint. Not the best solution, and it can be risky, but the plan is to quickly verify that all the weights that don't match are due only to the image prefix and adapters, so that people can train; after the first mismatch has been fixed, set strict=False (see the sketch below). We can find a better solution once everyone is able to use the code to port their changes and do test runs.
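A sketch of the strict=False check, assuming `model` is the MAGMA model as a plain nn.Module and reusing the renamed checkpoint from the sketch above; the real loading path goes through DeeperSpeed rather than load_state_dict directly, and the "image_prefix" / ".adapter" name patterns are assumptions.

```python
import torch

state_dict = torch.load("pythia_checkpoint_renamed.pt", map_location="cpu")["module"]
result = model.load_state_dict(state_dict, strict=False)

# Every missing key should belong to the image prefix or the adapters; anything
# else means a naming mismatch slipped through and still needs the rename fix.
leftover = [k for k in result.missing_keys
            if not (k.startswith("image_prefix") or ".adapter" in k)]
assert not result.unexpected_keys, result.unexpected_keys
assert not leftover, leftover
```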

@kshitijkg kshitijkg added the bug Something isn't working label May 30, 2023
@kshitijkg kshitijkg self-assigned this May 30, 2023
@kshitijkg kshitijkg pinned this issue May 30, 2023
@kshitijkg kshitijkg unpinned this issue May 30, 2023
@floatingbigcat floatingbigcat added the enhancement New feature or request label May 31, 2023
kshitijkg commented May 31, 2023

Current solution: the renaming approach (option 1 under "Changing the names"). Rename the weights from attention to attention.attn_block and mlp to mlp.attn_block, store the checkpoint again, and use the new checkpoint.
PR: #10

We just need to run the convert checkpoint script and use the converted checkpoint for loading.

Additionally, we set strict=False so that the image prefix and adapters are ignored. I have checked manually whether any other weights exist that don't have the right name, and everything looks correct.

Lastly, this requires another change in the DeeperSpeed code; use the following branch: https://github.com/EleutherAI/DeeperSpeed/tree/robin_summit

@kshitijkg kshitijkg added this to the Robin V0 milestone Jun 16, 2023