Add HF Checkpoint Format Support for Llama Vision #1727
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1727
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7efac3e with merge base 10b02e0. This comment was automatically generated by Dr. CI and updates every 15 minutes.
) -> Dict[str, torch.Tensor]:
    """
    Convertor from HF state dict to torchtune state dict. This handles:
    - Updateing the cross attention layer numbers
Suggested change:
- Updateing the cross attention layer numbers
- Updating the cross attention layer numbers
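(For context, a minimal self-contained sketch of the kind of key remapping such a converter performs; the mapping entries and helper below are illustrative stand-ins, not the actual _FROM_HF table or get_mapped_key from torchtune.)

import re
from typing import Dict

import torch

# Illustrative mapping with "{}" placeholders where layer indices go
# (made-up entries, not the real _FROM_HF table).
_EXAMPLE_FROM_HF = {
    "language_model.model.embed_tokens.weight": "decoder.tok_embeddings.weight",
    "language_model.model.layers.{}.self_attn.q_proj.weight": "decoder.layers.{}.attn.q_proj.weight",
}


def example_hf_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    converted = {}
    for key, value in state_dict.items():
        # Abstract away the numeric layer index so the key matches a template,
        # then substitute the index back into the mapped name.
        match = re.search(r"\.(\d+)\.", key)
        abstract_key = re.sub(r"\.(\d+)\.", ".{}.", key) if match else key
        new_key = _EXAMPLE_FROM_HF[abstract_key]
        if match:
            new_key = new_key.format(match.group(1))
        converted[new_key] = value
    return converted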
"tile_pos_embed.embedding" in new_key | ||
or "global_token_positional_embedding" in new_key | ||
): | ||
# WARNING |
?
Yeah please update this comment
Forgot to finish it.
) -> Dict[str, torch.Tensor]:
    """
    Convertor from Tune state dict to HF state dict. This handles:
    - Updateing the cross attention layer numbers
Suggested change:
- Updateing the cross attention layer numbers
- Updating the cross attention layer numbers
num_heads=text_config["num_attention_heads"],
num_kv_heads=text_config["num_key_value_heads"],
dim=text_config["hidden_size"],
head_dim=text_config.get("head_dim", None),
Why do some of these have defaults for .get and some do not have defaults? Could text_config and vision_config ever be empty?
I don't know, I was following the base implementation and all lists defaulted to None.
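(A small sketch of the pattern being discussed, with made-up config values: required fields are indexed directly so a missing key fails loudly, while optional fields go through .get, with or without an explicit default.)

# Illustrative config dict, not the real one parsed from a checkpoint's config.json.
text_config = {
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "hidden_size": 4096,
}

num_heads = text_config["num_attention_heads"]     # required: raises KeyError if missing
num_kv_heads = text_config["num_key_value_heads"]  # required
head_dim = text_config.get("head_dim", None)       # optional: explicit None default
rope_base = text_config.get("rope_theta")          # optional: .get with no default also yields None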
new_key = get_mapped_key(key, _FROM_HF)
if "language_model" in key:
    if "layers" in key:  # Update layer numbers
        layer = int(key.split(".")[3])
Not the end of the world, but I feel like half the reason we added the placeholders in e.g. _FROM_HF was so that we could avoid stuff like this (e.g. here) and keep the hf_to_tune methods simpler. I know the multimodal models are more involved, so maybe there's no way around it. But still, all of this is a bit hard to parse.
There are a number of challenges here that our setup doesn't handle out of the box, and I'm not quite sure if there's a good generalization. The layer numbering scheme for cross_attention is different between ours and theirs. There are also a number of instances where a single one of their keys maps to two of our keys. We could solve the latter generally, but I'm not sure about the former; we might need some layer mapping too. If we add more MM models, we might want to go down that route.
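(To make the renumbering problem concrete, here is a self-contained sketch under the assumption that HF interleaves cross-attention layers into a single layer list while torchtune counts the two layer types separately; the indices and helper are illustrative, not the PR's actual logic.)

# Assumed positions of cross-attention layers in the HF numbering (made up).
CROSS_ATTN_LAYERS_HF = [3, 8, 13, 18]


def remap_layer_idx(hf_idx: int) -> tuple:
    """Map an HF layer index to (layer_type, index_within_that_type)."""
    if hf_idx in CROSS_ATTN_LAYERS_HF:
        # Cross-attention layers get numbered by their position among
        # cross-attention layers only.
        return "cross_attn", CROSS_ATTN_LAYERS_HF.index(hf_idx)
    # Self-attention layers get numbered by their position among
    # self-attention layers only, i.e. the HF index minus the number of
    # cross-attention layers that precede it.
    n_cross_before = sum(1 for i in CROSS_ATTN_LAYERS_HF if i < hf_idx)
    return "self_attn", hf_idx - n_cross_before


# e.g. HF layer 8 -> ("cross_attn", 1); HF layer 9 -> ("self_attn", 7)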
from torchtune.models.llama3_2_vision._convert_weights import (
    llama3_vision_hf_to_tune,
What's the reason for gating the import here? Due to transitive torchvision dep or something?
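(If the gating is indeed about avoiding a heavy transitive dependency, the usual pattern looks something like the sketch below; the wrapper function name and kwargs pass-through are hypothetical.)

def _convert_vision_weights(state_dict, **conversion_kwargs):
    # Deferred import: modules that pull in optional/heavy dependencies
    # (e.g. torchvision via the vision model code) are only imported when
    # a vision checkpoint is actually being converted.
    from torchtune.models.llama3_2_vision._convert_weights import (
        llama3_vision_hf_to_tune,
    )

    return llama3_vision_hf_to_tune(state_dict, **conversion_kwargs)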
Re testing, can you add a check that running with Meta checkpointer and HF checkpointer give the same loss curves?
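(A minimal sketch of that sanity check, assuming the per-step losses from the two runs have already been collected into lists; the numbers are placeholders.)

import math

# Placeholder loss logs from two otherwise-identical runs, one loading the
# Meta-format checkpoint and one loading the HF-format checkpoint.
losses_meta = [1.732, 1.514, 1.401]
losses_hf = [1.732, 1.514, 1.401]

assert len(losses_meta) == len(losses_hf)
for step, (a, b) in enumerate(zip(losses_meta, losses_hf)):
    # If the two checkpointers load identical weights, the loss curves
    # should match step for step (up to numerical noise).
    assert math.isclose(a, b, rel_tol=1e-4), f"step {step}: {a} vs {b}"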
elif new_key == "decoder.tok_embeddings.weight":
    learned_embedding = "decoder.tok_embeddings.fusion_embedding.weight"
    converted_state_dict[learned_embedding] = value[vocab_size:]
    value = value[:vocab_size]
Please add a comment for this
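(For example, a comment along these lines would capture what the split is doing; the wording below is a suggestion, not taken from the PR.)

elif new_key == "decoder.tok_embeddings.weight":
    # HF stores the base vocabulary rows and the extra learned (fusion)
    # token rows in a single embedding matrix; torchtune keeps them as two
    # separate parameters. Split at vocab_size: rows [0, vocab_size) stay
    # in tok_embeddings, the remaining rows become the fusion embedding.
    learned_embedding = "decoder.tok_embeddings.fusion_embedding.weight"
    converted_state_dict[learned_embedding] = value[vocab_size:]
    value = value[:vocab_size]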
A couple of small comments; the main one is that I'd like to see loss curves for the two checkpointers as a sanity check. Stamping to unblock.
Context
What is the purpose of this PR?

Changelog
Added hf_to_tune and tune_to_hf mappings.

Test plan
tune run full_finetune_single_device --config llama3_2_vision/11B_full_single_device max_steps_per_epoch=1