
Add HF Checkpoint Format Support for Llama Vision #1727

Merged: 4 commits into pytorch:main from hf_checkpointer on Oct 1, 2024

Conversation

@pbontrager (Contributor) commented Oct 1, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Added hf_to_tune and tune_to_hf mappings.

Changelog

  • updated llama3_2_vision/_convert_weights.py to include the mappings
  • updated _checkpointer.py to call them

Test plan

  • Tested running 1 epoch of 1 step with the HF checkpoint format: tune run full_finetune_single_device --config llama3_2_vision/11B_full_single_device max_steps_per_epoch=1
  • Verified meta_to_tune(meta_checkpoint) == hf_to_tune(hf_checkpoint)
  • Verified tune_to_hf(hf_to_tune(hf_checkpoint)) == hf_checkpoint
  • Verified tune_to_meta(meta_to_tune(meta_checkpoint)) == meta_checkpoint (see the sketch after this list)
  • HF training run showed comparable loss curve to Meta
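
A minimal sketch of those roundtrip checks, not the PR's actual test code: the converter names follow the test-plan wording and are passed in as parameters here, since their real counterparts in _convert_weights.py also take model config arguments (num_heads, dim, etc.); hf_checkpoint and meta_checkpoint are assumed to be state dicts already loaded from disk.

```python
import torch


def state_dicts_match(a: dict, b: dict) -> bool:
    # Same set of keys and element-wise identical tensors.
    return a.keys() == b.keys() and all(torch.equal(a[k], b[k]) for k in a)


def check_roundtrips(hf_checkpoint, meta_checkpoint, hf_to_tune, tune_to_hf, meta_to_tune, tune_to_meta):
    # The three equalities listed in the test plan (converter signatures simplified).
    assert state_dicts_match(meta_to_tune(meta_checkpoint), hf_to_tune(hf_checkpoint))
    assert state_dicts_match(tune_to_hf(hf_to_tune(hf_checkpoint)), hf_checkpoint)
    assert state_dicts_match(tune_to_meta(meta_to_tune(meta_checkpoint)), meta_checkpoint)
```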

pytorch-bot bot commented Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1727

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7efac3e with merge base 10b02e0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 1, 2024
) -> Dict[str, torch.Tensor]:
"""
Convertor from HF state dict to torchtune state dict. This handles:
- Updateing the cross attention layer numbers

Contributor:

Suggested change
- Updateing the cross attention layer numbers
- Updating the cross attention layer numbers

"tile_pos_embed.embedding" in new_key
or "global_token_positional_embedding" in new_key
):
# WARNING

Contributor:

?

Contributor:

Yeah please update this comment

pbontrager (Author):

Forgot to finish it.

) -> Dict[str, torch.Tensor]:
"""
Convertor from Tune state dict to HF state dict. This handles:
- Updateing the cross attention layer numbers

Contributor:


Suggested change
- Updateing the cross attention layer numbers
- Updating the cross attention layer numbers

num_heads=text_config["num_attention_heads"],
num_kv_heads=text_config["num_key_value_heads"],
dim=text_config["hidden_size"],
head_dim=text_config.get("head_dim", None),

Contributor:

why do some of these have defaults for .get and some do not have defaults? could text_config and vision_config ever be empty?

pbontrager (Author):

I don't know; I was following the base implementation, and all lists defaulted to None there.
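
To make the question concrete, a sketch of the optional-field pattern under discussion. The config dict and its values below are made up for illustration, and the head_dim fallback shown is a common convention, not necessarily what the checkpointer does when the value is None.

```python
# Stand-in for a parsed HF config.json (values illustrative only).
config = {
    "text_config": {
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "hidden_size": 4096,
        # "head_dim" intentionally absent to exercise the optional path
    }
}

text_config = config["text_config"]
num_heads = text_config["num_attention_heads"]  # required: KeyError if missing
head_dim = text_config.get("head_dim", None)    # optional: None if missing
if head_dim is None:
    # One common fallback; hedged, not necessarily torchtune's behavior.
    head_dim = text_config["hidden_size"] // num_heads
```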

new_key = get_mapped_key(key, _FROM_HF)
if "language_model" in key:
if "layers" in key: # Update layer numbers
layer = int(key.split(".")[3])

Contributor:

Not the end of the world, but I feel like half the reason we added the placeholders in e.g. _FROM_HF was so that we could avoid stuff like this (e.g. here) and keep the hf_to_tune methods simpler. I know the multimodal models are more involved, so maybe there's no way around it. But still, all of this is a bit hard to parse

pbontrager (Author):

There are a number of challenges here that our setup doesn't handle out of the box, and I'm not quite sure there's a good generalization. The layer numbering scheme for cross_attention is different between ours and theirs, and there are also a number of instances where a single one of their keys maps to two of our keys. We could solve the latter generally, but I'm not sure about the former; we might need some layer mapping too. If we add more MM models, we might want to go down that route.
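
To illustrate the shape of the problem, a hypothetical sketch of index rewriting that a flat key template can't express. Only the split-on-index-3 pattern matches the snippet above; the stride-based renumbering is invented for illustration and is not the scheme this PR uses.

```python
def rewrite_layer_index(key: str, stride: int = 4) -> str:
    # Hypothetical: HF counts every decoder block in one running index, while
    # the target numbering counts cross-attention (fusion) layers separately,
    # one every `stride` blocks. A fixed "layers.{}" template can't encode that
    # arithmetic, so the converter rewrites the index by hand.
    parts = key.split(".")
    hf_layer = int(parts[3])            # same position as key.split(".")[3] above
    parts[3] = str(hf_layer // stride)  # invented arithmetic, illustration only
    return ".".join(parts)
```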

Comment on lines +470 to +471
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_hf_to_tune,

Contributor:

What's the reason for gating the import here? Due to transitive torchvision dep or something?
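
For context, a sketch of the lazy-import pattern being asked about. The import path is the one shown in the diff above; the wrapper function, the model-type string, and the simplified call signature are illustrative assumptions, not the checkpointer's real structure.

```python
from typing import Dict

import torch


def convert_hf_to_tune(state_dict: Dict[str, torch.Tensor], model_type: str) -> Dict[str, torch.Tensor]:
    if model_type == "LLAMA3_VISION":  # hypothetical model-type value
        # Deferring the import keeps text-only checkpoint loads from importing
        # the vision module (and any transitive torchvision dependency).
        from torchtune.models.llama3_2_vision._convert_weights import (
            llama3_vision_hf_to_tune,
        )

        return llama3_vision_hf_to_tune(state_dict)  # real function takes extra config args
    raise ValueError(f"Unsupported model type: {model_type}")
```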

@ebsmothers (Contributor) left a comment:

Re testing, can you add a check that running with Meta checkpointer and HF checkpointer give the same loss curves?
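
A minimal sketch of what such a check could look like, assuming the per-step losses from the two runs have already been collected into lists (how they are logged and parsed is outside this PR):

```python
import math
from typing import Sequence


def loss_curves_match(meta_losses: Sequence[float], hf_losses: Sequence[float], rel_tol: float = 1e-3) -> bool:
    # Same number of steps, and each per-step loss within a relative tolerance.
    return len(meta_losses) == len(hf_losses) and all(
        math.isclose(m, h, rel_tol=rel_tol) for m, h in zip(meta_losses, hf_losses)
    )
```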

Comment on lines 298 to 301
elif new_key == "decoder.tok_embeddings.weight":
learned_embedding = "decoder.tok_embeddings.fusion_embedding.weight"
converted_state_dict[learned_embedding] = value[vocab_size:]
value = value[:vocab_size]

Contributor:

Please add a comment for this
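
For reference, a commented version of the same four lines; the wording of the comments is illustrative, not the comment that was ultimately added in the PR.

```python
elif new_key == "decoder.tok_embeddings.weight":
    # HF appends the extra learned tokens (beyond the base vocab) as additional
    # rows of the token embedding; torchtune stores them in a separate fusion
    # embedding. Split the matrix: rows past vocab_size become the fusion
    # embedding weight, and the first vocab_size rows stay as tok_embeddings.
    learned_embedding = "decoder.tok_embeddings.fusion_embedding.weight"
    converted_state_dict[learned_embedding] = value[vocab_size:]
    value = value[:vocab_size]
```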

@ebsmothers (Contributor) left a comment:

Couple of small comments; the main one is that I'd like to see loss curves for the two checkpointers as a sanity check. Stamping to unblock.

@pbontrager merged commit 3c91e42 into pytorch:main on Oct 1, 2024
17 checks passed
@pbontrager deleted the hf_checkpointer branch on October 1, 2024 at 20:38