Fix LoRA weight merging for FSDP, integrate checkpointer #506
Conversation
Thanks so much for the quick turnaround on this. I'll let you make sure you add the unit tests and the test for `resume_from_checkpoint=True` before merging.
LG overall, will wait for tests/docs etc
recipes/lora_finetune_distributed.py (Outdated)

```python
    msg=f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
)
# Construct the full state dict with LoRA weights merged into base LLM weights
sd = {k: v.cpu() for k, v in self._model.state_dict().items()}
```
Why do this on the CPU?
I'm guessing that on GPU you'd need to clone and create a copy, otherwise you'd update the base model weights? And if you clone you'll double the memory? Anyway, torch.save will write to a CPU buffer before writing to file, so it might not be that much of an overhead?
Yeah it's to avoid making a copy. Lmk if there's an obviously better way to do this though
On second thought, what do you think about doing the .cpu() call inside the function? Otherwise, if someone calls this, their state_dict can get modified without them realizing it, which is probably not the desired effect.
What does "this" mean here?
TBH, I do like that this is explicitly called in the recipe (maybe you need to document this better). The CPU move is to explicitly avoid memory overhead, and that's a recipe thing, not a merge thing. The merge function should be usable on whichever device I want, and I should be responsible for using it appropriately. Does that make sense?
Sorry, "this" here == `get_merged_lora_ckpt`. But yeah, based on that rationale it makes sense to do it in the recipe then.
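For concreteness, a minimal sketch of the recipe-side flow discussed in this thread, assuming the merge utility operates on a plain state dict. The name `get_merged_lora_ckpt` comes from the thread above, but the import path and exact signature here are assumptions, not necessarily the code that landed:

```python
import torch

# Import path assumed for illustration; in this PR the utility lives under
# torchtune/modules/peft/peft_utils.py.
from torchtune.modules.peft import get_merged_lora_ckpt


def save_merged_checkpoint(model, output_loc, lora_rank, lora_alpha):
    # Move the state dict to CPU first, so the merge neither mutates the live
    # GPU params nor doubles GPU memory by cloning (per the discussion above).
    cpu_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
    # Assumed call shape; the real utility may take different arguments.
    merged = get_merged_lora_ckpt(cpu_state_dict, rank=lora_rank, alpha=lora_alpha)
    torch.save(merged, output_loc)
```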
torchtune/modules/peft/peft_utils.py (Outdated)

```python
        and callable(m._merge_lora_weights)
        and hasattr(m, "_unmerge_lora_weights")
        and callable(m._unmerge_lora_weights)


def get_lora_modules(state_dict: Dict[str, Any]):
```
The next 2 are public API, do we want to docstring them, etc.?
Prob gonna make at least `_get_lora_modules` private, but will add docstrings for both regardless.
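For illustration, a minimal sketch of what such a key-based helper could look like; the `.lora_a.weight` key naming is an assumption of the sketch, not necessarily the real implementation:

```python
from typing import Any, Dict, Set


def _get_lora_modules(state_dict: Dict[str, Any]) -> Set[str]:
    """Return the names of LoRA-adapted modules, inferred from state dict keys only.

    E.g. "layers.0.attn.q_proj.lora_a.weight" -> "layers.0.attn.q_proj".
    """
    return {
        key.replace(".lora_a.weight", "")
        for key in state_dict
        if key.endswith(".lora_a.weight")
    }
```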
```python
if (
    self.seed != ckpt_dict[utils.SEED_KEY]
    or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
    or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
):
    warn(
        message="""Configured value for seed, epochs or max_steps_per_epoch
        does not match the value stored in checkpoint."""
    )
self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]
```
One thing I did for full finetune but forgot to do for lora single device was wrap this in a try-catch block and raise a more meaningful exception than just KeyError. Now that I think about it, I don't know how meaningful that actually was. Will let you see if that's worth adding to LoRA as well.
Did something in between the two, lmk if it looks reasonable to you.
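One possible shape for that middle ground, sketched as a recipe method under the assumption that the `utils.*_KEY` constants shown above are available; this is an illustration, not necessarily the code that landed:

```python
from warnings import warn

from torchtune import utils  # assumed import; the KEY constants appear in the snippet above


def _update_recipe_state(self, ckpt_dict):
    # (Method on the recipe class; shown standalone for brevity.)
    # Wrap key access so a malformed checkpoint surfaces a descriptive error
    # rather than a bare KeyError, per the review suggestion above.
    try:
        if (
            self.seed != ckpt_dict[utils.SEED_KEY]
            or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
            or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
        ):
            warn(
                "Configured value for seed, epochs or max_steps_per_epoch "
                "does not match the value stored in checkpoint."
            )
        self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
        self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
        self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
        self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]
    except KeyError as e:
        raise KeyError(
            "Checkpoint does not contain the keys required to update recipe "
            "state. Are you sure it was saved for resume_from_checkpoint?"
        ) from e
```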
Context

Our previous approach to merging LoRA weights, based on state dict hooks, did not work with FSDP. Unfortunately we missed this because our recipe test only ran FSDP on a single device, and so we missed a tricky bug that only manifests when the strategy is `FULL_SHARD`.

The bug is due to different params of `LoRALinear` being wrapped in different FSDP blocks. We wrap the LoRA A and B matrices each in their own blocks, because FSDP will allocate memory for grads at the FSDP block level (so if we wrapped them in the usual way inside our `TransformerDecoderLayer`, memory would be allocated for all the other frozen params in that module, kinda defeating the whole memory-saving aspect of LoRA).

Unfortunately, when we perform the weight merge operations on the sharded LoRA params, we need to call `summon_full_params` to gather the full parameters to the device (as opposed to the flat params kept by FSDP). This needs to be called on each FSDP instance involved in the operation. In `LoRALinear` this would be `self.lora_a`, `self.lora_b`, and the FSDP instance containing `self.weight`. As mentioned above, `self.lora_a` and `self.lora_b` are wrapped in their own FSDP blocks, but `self.weight` is only wrapped in the parent module `TransformerDecoderLayer`. This means we cannot define the weight merge hooks on `LoRALinear`, and would instead need to define them on `TransformerDecoderLayer`. This would be quite unintuitive, so we are going another route.
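For illustration, a rough sketch of what the abandoned hook-based approach would have required per LoRA layer; the function and argument names here are hypothetical placeholders, not torchtune code:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def hook_based_merge_sketch(lora_a_fsdp: FSDP, lora_b_fsdp: FSDP, parent_fsdp: FSDP) -> None:
    # Illustrative only: the hook-based merge would need the full (unsharded)
    # params of every FSDP instance touched by the merge. For LoRALinear that
    # is three instances: lora_a, lora_b, and the parent TransformerDecoderLayer
    # wrapping the frozen base weight.
    with FSDP.summon_full_params(lora_a_fsdp), FSDP.summon_full_params(
        lora_b_fsdp
    ), FSDP.summon_full_params(parent_fsdp):
        # With full params gathered, an in-place merge such as
        # W += (alpha / rank) * B @ A could run here before the params are
        # resharded on context exit.
        pass
```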
Changelog

- Instead, we now take `model.state_dict()`, move it to CPU, infer the LoRA modules (based only on state dict keys), then perform the weight merge directly in the dictionary, throwing out the corresponding `lora_a` and `lora_b` keys when we're done.
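A minimal sketch of that in-dictionary merge, assuming key names like `<module>.lora_a.weight` and the standard LoRA merge `W' = W + (alpha / rank) * B @ A`; the real `get_merged_lora_ckpt` may differ in its details:

```python
from typing import Dict

import torch


def merge_lora_in_state_dict(
    state_dict: Dict[str, torch.Tensor], rank: int, alpha: float
) -> Dict[str, torch.Tensor]:
    # Illustrative sketch, not the exact torchtune implementation: find every
    # LoRA-adapted module from the keys alone, fold the low-rank update into
    # the base weight, then drop the lora_a / lora_b keys.
    lora_modules = {
        k.replace(".lora_a.weight", "")
        for k in state_dict
        if k.endswith(".lora_a.weight")
    }
    for module in lora_modules:
        lora_a = state_dict.pop(f"{module}.lora_a.weight")
        lora_b = state_dict.pop(f"{module}.lora_b.weight")
        state_dict[f"{module}.weight"] = (
            state_dict[f"{module}.weight"] + (alpha / rank) * lora_b @ lora_a
        )
    return state_dict
```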
Test plan

- Added `test_get_merged_lora_ckpt` to confirm expected weights and forward pass parity after merging on a very simple model (the parity idea is sketched after this list).
- Added `test_training_state_on_resume` and `test_save_and_load_merged_weights`. Both tests run for single device and distributed recipe. (The distributed recipe test still runs on a single device and so does not exercise `FULL_SHARD`; once we have distributed CI set up we should change this, same as for `test_full_finetune.py`.)
- Also confirmed that FSDP on multiple devices (i.e. with `FULL_SHARD`) works via:
- Separately confirmed max memory by printing memory stats using the utility from #391.
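The parity idea behind `test_get_merged_lora_ckpt`, sketched with a hand-rolled LoRA layer rather than torchtune's actual modules; the shapes and the merge formula used here are assumptions of the sketch:

```python
import torch
from torch import nn

# After merging, a plain Linear holding the merged weight should reproduce the
# base + LoRA output on the same input.
torch.manual_seed(0)
in_dim, out_dim, rank, alpha = 8, 8, 2, 4.0

base = nn.Linear(in_dim, out_dim, bias=False)
lora_a = torch.randn(rank, in_dim)   # A: projects down to the low-rank space
lora_b = torch.randn(out_dim, rank)  # B: projects back up

x = torch.randn(4, in_dim)
lora_out = base(x) + (alpha / rank) * (x @ lora_a.T @ lora_b.T)

merged = nn.Linear(in_dim, out_dim, bias=False)
with torch.no_grad():
    merged.weight.copy_(base.weight + (alpha / rank) * lora_b @ lora_a)

# Forward pass parity between the merged Linear and the base + LoRA path.
torch.testing.assert_close(merged(x), lora_out)
```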