fix(modeling): deepspeed checkpoint loading #482

Merged: 15 commits merged into main from fix-checkpoint-loading on Aug 8, 2023

Conversation

@maxreciprocate (Collaborator) commented on May 22, 2023

This PR adds an option to load and resume training from previously saved DeepSpeed checkpoints.

Example of successive training:

accelerate launch --config_file configs/accelerate/zero3.yaml examples/ppo_sentiments.py
accelerate launch --config_file configs/accelerate/zero3.yaml examples/ppo_sentiments.py '{"train": {"resume_from_checkpoint": "ckpts/best_checkpoint"}}'
  • Verify ZeRO3 loading
  • Update docstrings
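
For reference, the same resume override can be expressed through the Python API instead of the CLI JSON override above. The following is a hedged sketch, not code from this PR: default_ppo_config and trlx.train are assumed from trlx's examples, and the reward function and prompts are placeholders.

    # Hedged sketch: resume training from a saved checkpoint via the config object.
    import trlx
    from trlx.data.default_configs import default_ppo_config  # assumed helper from trlx examples

    config = default_ppo_config()
    # Option added by this PR: point training at a previously saved checkpoint.
    config.train.resume_from_checkpoint = "ckpts/best_checkpoint"

    trlx.train(
        reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],  # placeholder reward
        prompts=["I enjoyed this movie because"] * 64,                         # placeholder prompts
        config=config,
    )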

@jon-tow (Collaborator) left a comment

Thanks for working on this, Max!

Checkpointing works properly for models that do not freeze any layers (at least under ZeRO-2), but not otherwise. For example, with num_layers_unfrozen=2 I get the following error:

RuntimeError: Error(s) in loading state_dict for AutoModelForCausalLMWithHydraValueHead:
        Missing key(s) in state_dict: "frozen_head.decoder_blocks.0.ln_1.weight", 
"frozen_head.decoder_blocks.0.ln_1.bias", "frozen_head.decoder_blocks.0.attn.bias", 
"frozen_head.decoder_blocks.0.attn.masked_bias", "frozen_head.decoder_blocks.0.attn.c_attn.weight", 
"frozen_head.decoder_blocks.0.attn.c_attn.bias", "frozen_head.decoder_blocks.0.attn.c_proj.weight", 
"frozen_head.decoder_blocks.0.attn.c_proj.bias", "frozen_head.decoder_blocks.0.ln_2.weight", 
"frozen_head.decoder_blocks.0.ln_2.bias", "frozen_head.decoder_blocks.0.mlp.c_fc.weight", 
"frozen_head.decoder_blocks.0.mlp.c_fc.bias", "frozen_head.decoder_blocks.0.mlp.c_proj.weight", 
"frozen_head.decoder_blocks.0.mlp.c_proj.bias", "frozen_head.decoder_blocks.1.ln_1.weight", 
"frozen_head.decoder_blocks.1.ln_1.bias", "frozen_head.decoder_blocks.1.attn.bias", 
"frozen_head.decoder_blocks.1.attn.masked_bias", "frozen_head.decoder_blocks.1.attn.c_attn.weight", 
"frozen_head.decoder_blocks.1.attn.c_attn.bias", "frozen_head.decoder_blocks.1.attn.c_proj.weight", 
"frozen_head.decoder_blocks.1.attn.c_proj.bias", "frozen_head.decoder_blocks.1.ln_2.weight", 
"frozen_head.decoder_blocks.1.ln_2.bias", "frozen_head.decoder_blocks.1.mlp.c_fc.weight", 
"frozen_head.decoder_blocks.1.mlp.c_fc.bias", "frozen_head.decoder_blocks.1.mlp.c_proj.weight", 
"frozen_head.decoder_blocks.1.mlp.c_proj.bias", "frozen_head.final_norm.weight", "frozen_head.final_norm.bias", 
"frozen_head.lm_head.weight". 

We'll probably need to filter frozen_head from state_dict when checkpointing. What do you think?
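
For concreteness, one possible shape of that filtering, as a hedged sketch only (strip_frozen_head is a hypothetical helper, not part of trlx, and not necessarily the change this PR ends up making): drop the frozen_head.* entries before saving, and load non-strictly so the rebuilt frozen head is not expected in the checkpoint.

    import torch

    def strip_frozen_head(state_dict):
        """Drop the frozen hydra-head parameters before checkpointing."""
        return {k: v for k, v in state_dict.items() if not k.startswith("frozen_head.")}

    def save_trainable(model, path):
        # Persist only the trainable weights; the frozen head is excluded.
        torch.save(strip_frozen_head(model.state_dict()), path)

    def load_trainable(model, path):
        # The frozen head is rebuilt from the base model at init time, so it is
        # not expected in the checkpoint; strict=False tolerates the missing keys.
        model.load_state_dict(torch.load(path), strict=False)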

Two resolved review threads on trlx/models/modeling_ilql.py (outdated).

Diff excerpt from another review thread:

    for k, v in v_head_state_dict.items():
        base_model_state_dict[f"v_head.{k}"] = v
    return base_model_state_dict

    base_model_state_dict = self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs))

Collaborator comment:

Same as for ILQL - need to add support for Seq2Seq.
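
For context, a prefixed state_dict override of that shape looks roughly like the following. This is a hedged reconstruction from the excerpt above, not necessarily the exact code merged, and the Seq2Seq variant asked for here is not shown.

    # Hedged reconstruction based on the diff excerpt above.
    def state_dict(self, *args, **kwargs):
        # Prefix base-model weights so they can be told apart from the
        # value-head weights when the combined dict is loaded back.
        base_model_state_dict = self.base_model.state_dict(
            *args, **dict(prefix="base_model.", **kwargs)
        )
        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
        for k, v in v_head_state_dict.items():
            base_model_state_dict[f"v_head.{k}"] = v
        return base_model_state_dict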

A resolved review thread on trlx/trlx.py.
@jon-tow (Collaborator) left a comment

Looks good to me!

@Dahoas (Collaborator) commented on Jul 10, 2023

@maxreciprocate Do you want to resolve the conflicts and then we will merge?

@Dahoas (Collaborator) commented on Aug 4, 2023

@maxreciprocate Bump on this

@Dahoas (Collaborator) commented on Aug 8, 2023

Looks good, merging

@Dahoas merged commit 2e667e6 into main on Aug 8, 2023 (2 checks passed).
@maxreciprocate deleted the fix-checkpoint-loading branch on Aug 8, 2023 at 18:27.