
Fix LoRA weight merging for FSDP, integrate checkpointer #506

Merged: ebsmothers merged 3 commits into main on Mar 15, 2024

Conversation

@ebsmothers (Contributor) commented Mar 15, 2024

Context

Our previous approach to merging LoRA weights, based on state dict hooks, did not work with FSDP. Unfortunately our recipe test only ran FSDP on a single device, so we missed a tricky bug that only manifests when the sharding strategy is FULL_SHARD.

The bug is due to different params of LoRALinear being wrapped in different FSDP blocks. We wrap the LoRA A and B matrices each in their own blocks because FSDP allocates gradient memory at the FSDP-block level (so if we wrapped them in the usual way inside our TransformerDecoderLayer, memory would be allocated for all the other frozen params in that module, largely defeating the memory-saving aspect of LoRA).
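
For illustration, a minimal sketch of what such per-module wrapping can look like using PyTorch FSDP's custom auto_wrap_policy callable; the _is_lora_adapter marker and the policy itself are hypothetical, not torchtune's actual wrapping code:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def lora_wrap_policy(module, recurse, nonwrapped_numel) -> bool:
    # While recursing down the module tree, keep going into children.
    if recurse:
        return True
    # Hypothetical marker: wrap each LoRA A/B linear in its own FSDP unit so
    # gradient memory is only allocated for the trainable low-rank params.
    if getattr(module, "_is_lora_adapter", False):
        return True
    # Otherwise wrap at the TransformerDecoderLayer level as usual.
    return type(module).__name__ == "TransformerDecoderLayer"

# sharded_model = FSDP(model, auto_wrap_policy=lora_wrap_policy)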

Unfortunately, when we perform the weight merge on the sharded LoRA params, we need to call summon_full_params to gather the full parameters to the device (as opposed to the flat params kept by FSDP). This needs to be called on each FSDP instance involved in the operation; for LoRALinear that would be self.lora_a, self.lora_b, and the FSDP instance containing self.weight. As mentioned above, self.lora_a and self.lora_b are wrapped in their own FSDP blocks, but self.weight is only wrapped in the parent module TransformerDecoderLayer. This means we cannot define the weight-merge hooks on LoRALinear and would instead need to define them on TransformerDecoderLayer. That would be quite unintuitive, so we take a different route.
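
A rough sketch of why a LoRALinear-level hook can't do the merge under FULL_SHARD; module names here are illustrative only:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def merge_hook_sketch(lora_linear, parent_decoder_layer):
    # A hook on LoRALinear could gather its own adapter shards ...
    with FSDP.summon_full_params(lora_linear.lora_a), FSDP.summon_full_params(lora_linear.lora_b):
        pass  # lora_a / lora_b are unsharded (full) inside this context
    # ... but lora_linear.weight lives in the flat param owned by the parent
    # TransformerDecoderLayer's FSDP unit, so gathering it requires
    # FSDP.summon_full_params(parent_decoder_layer), which a hook defined on
    # LoRALinear has no clean way to reach.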

Changelog

  • Instead of relying on state dict hooks, we do the dumb thing and directly perform weight merging on the fully-realized state dict.
    • To save merged weights, we call model.state_dict(), move it to CPU, infer the LoRA modules (based only on state dict keys), then perform the weight merge directly in the dictionary, throwing out the corresponding lora_a and lora_b keys when we're done (a rough sketch of this merge is shown after this list).
    • It's not pretty, but it works with FSDP and can be done in ~14 GB of memory when saving in bf16.
  • Also, the new checkpointer APIs make the UX of this a lot less clunky. Since we were going to migrate this recipe to them eventually anyway, what better time to do it than now?
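
A rough sketch of the merge performed on the realized state dict (the key names and the merge_lora_state_dict helper are illustrative, not the exact torchtune code, but the update follows the standard LoRA formula W' = W + (alpha / rank) * B @ A):

from typing import Dict

import torch

def merge_lora_state_dict(
    state_dict: Dict[str, torch.Tensor], rank: int, alpha: float
) -> Dict[str, torch.Tensor]:
    # Infer the LoRA modules purely from state dict keys.
    lora_modules = {
        k[: -len(".lora_a.weight")]
        for k in state_dict
        if k.endswith(".lora_a.weight")
    }
    for module in sorted(lora_modules):
        # Pop the adapter weights so they're dropped from the final checkpoint.
        lora_a = state_dict.pop(f"{module}.lora_a.weight")
        lora_b = state_dict.pop(f"{module}.lora_b.weight")
        # Merge in place: W' = W + (alpha / rank) * B @ A
        state_dict[f"{module}.weight"] += (alpha / rank) * (lora_b @ lora_a)
    return state_dict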

Test plan

  • Add unit test test_get_merged_lora_ckpt to confirm expected weights and forward pass parity after merging on a very simple model.
  • Add recipe tests test_training_state_on_resume and test_save_and_load_merged_weights. Both tests run for the single-device and distributed recipes.
    • Note that the distributed recipe still only runs on a single device and so is not using FULL_SHARD. But once we have distributed CI set up we should change this.
    • The first test is similar to the analogous test in test_full_finetune.py.
    • The second test explicitly loads checkpointed adapter weights + base model weights into a LoRA Llama2 class, then loads the merged LoRA checkpoint into a base model class and checks forward parity between the two (a sketch of this check follows the list).
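
A hedged sketch of that forward-parity check (the model handles and shapes are placeholders, not the actual test code):

import torch

@torch.no_grad()
def check_forward_parity(lora_model, merged_base_model, vocab_size=32_000, seq_len=16):
    # Same token batch through both models; outputs should match if the merged
    # checkpoint is equivalent to base weights + LoRA adapters.
    tokens = torch.randint(0, vocab_size, (2, seq_len))
    torch.testing.assert_close(lora_model(tokens), merged_base_model(tokens))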

Also confirmed that FSDP on multiple devices (i.e. with FULL_SHARD) works via

tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config alpaca_llama2_lora_finetune_distributed checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir=/data/users/ebs/checkpoints checkpointer.checkpoint_files=['llama2-7b-torchtune.pt'] tokenizer.path=/data/users/ebs/checkpoints/lora-debug/tokenizer.model model.lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] model.apply_lora_to_mlp=True model.apply_lora_to_output=True model.lora_rank=8 model.lora_alpha=16 batch_size=2 max_steps_per_epoch=2 full_bf16=True checkpointer.output_dir=/data/users/ebs/lora-debug/test-weight-merge-fix
...
Model checkpoint of size 26.95 GB saved to /data/users/ebs/lora-debug/test-weight-merge-fix/torchtune_model_0.pt
Adapter checkpoint of size 0.08 GB saved to /data/users/ebs/lora-debug/test-weight-merge-fix/adapter_0.pt
Model checkpoint of size 26.95 GB saved to /data/users/ebs/lora-debug/test-weight-merge-fix/torchtune_model_0.pt
Adapter checkpoint of size 0.08 GB saved to /data/users/ebs/lora-debug/test-weight-merge-fix/adapter_0.pt

Separately confirmed max memory by printing memory stats using the utility from #391.


pytorch-bot bot commented Mar 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/506

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4ab1508 with merge base 9c75d48:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Mar 15, 2024

netlify bot commented Mar 15, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: 4ab1508
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65f47d638661650009d3648c
😎 Deploy Preview: https://deploy-preview-506--torchtune-preview.netlify.app

@kartikayk (Contributor) left a comment

Thanks so much for the quick turnaround on this. I'll let you make sure you add the unit tests and the test for resume_from_checkpoint=True before merging.

@rohan-varma (Member) left a comment

LG overall, will wait for tests/docs etc

msg=f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
)
# Construct the full state dict with LoRA weights merged into base LLM weights
sd = {k: v.cpu() for k, v in self._model.state_dict().items()}

Member: Why construct it on the CPU?

Contributor: I'm guessing on GPU you'd need to clone and create a copy, otherwise you'd update the base model weights? And if you clone you'll double the memory? Anyway, torch.save will write to a CPU buffer before writing to file, so it might not be that much of an overhead.

Contributor Author: Yeah it's to avoid making a copy. Lmk if there's an obviously better way to do this though.

Contributor Author: On second thought, what do you think about doing the .cpu() call inside the function? Otherwise, if someone calls this, their state_dict can get modified without them realizing it, which is probably not the desired effect.

Contributor: What does "this" mean here?

TBH, I do like that this is explicitly called in the recipe (maybe you need to document this better). The CPU move is there to explicitly avoid memory overhead, and that's a recipe thing, not a merge thing. The merge function should be usable on whichever device I want, and I should be responsible for using it appropriately. Does that make sense?

Contributor Author: Sorry, "this" here == get_merged_lora_ckpt. But yeah, based on that rationale it makes sense to do it in the recipe then.
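
An illustration of that split, reusing the hypothetical merge_lora_state_dict helper from the earlier sketch: the recipe owns the CPU move, while the merge helper stays device-agnostic.

import torch

def save_merged_checkpoint(model, output_path, rank, alpha):
    # Recipe-side: copy tensors to CPU so the in-memory GPU weights are not
    # mutated and no second GPU copy is allocated.
    sd = {k: v.cpu() for k, v in model.state_dict().items()}
    merged = merge_lora_state_dict(sd, rank=rank, alpha=alpha)  # hypothetical helper
    torch.save(merged, output_path)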

and callable(m._merge_lora_weights)
and hasattr(m, "_unmerge_lora_weights")
and callable(m._unmerge_lora_weights)
def get_lora_modules(state_dict: Dict[str, Any]):

Member: The next 2 are public API, do we want docstrings etc.?

Contributor Author: Prob gonna make at least _get_lora_modules private, but will add docstrings for both regardless.

Comment on lines +116 to +128
if (
self.seed != ckpt_dict[utils.SEED_KEY]
or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
):
warn(
message="""Configured value for seed, epochs or max_steps_per_epoch
does not match the value stored in checkpoint."""
)
self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]

Contributor: One thing I did for full finetune but forgot to do for LoRA single device was wrap this in a try-catch block with a more meaningful exception than just KeyError. Now that I think about it, I don't know how meaningful that actually was. Will let you see if that's worth adding to LoRA as well.

Contributor Author: Did something in between the two, lmk if it looks reasonable to you.

@ebsmothers ebsmothers merged commit 25286f4 into main Mar 15, 2024
21 checks passed
@ebsmothers ebsmothers deleted the fix-lora_weight-merge-fsdp branch March 15, 2024 17:05
@ebsmothers changed the title from "[WIP] Fix LoRA weight merging for FSDP, integrate checkpointer" to "Fix LoRA weight merging for FSDP, integrate checkpointer" Mar 15, 2024
@ebsmothers ebsmothers mentioned this pull request Mar 15, 2024