
Fix floating point precision issue for RoPE #23837

Closed

Conversation

butsugiri

What does this PR do?

This PR fixes a floating point precision issue in RotaryEmbedding.
Its purpose is to resolve an inconsistency between GPT-NeoX and HF Transformers that causes a degradation in model performance.

Issue

In the current implementation of RotaryEmbedding, inv_freq is first initialized in float32.
This value is then used to initialize cos_cached and sin_cached, also in float32.
As a result, cos_cached and sin_cached remain float32 even if the model (including inv_freq) is cast to float16, because these two attributes are not targets of the dtype conversion performed by the half() method.
Note that there is also recomputation logic for these two tensors, but it is very unlikely to run:

# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
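For illustration, here is a minimal sketch of the caching pattern described above (a simplification, not the exact Transformers source): cos_cached and sin_cached are plain attributes built in `__init__`, so a later half() call converts the inv_freq buffer but leaves the caches in float32.

```python
import torch


class SketchRotaryEmbedding(torch.nn.Module):
    """Simplified sketch of the cache-in-__init__ pattern described above."""

    def __init__(self, dim, max_position_embeddings, base=10000):
        super().__init__()
        # inv_freq is a registered buffer, so .half() converts it.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # cos/sin are precomputed once in float32 and stored as plain
        # attributes, so .half() does NOT convert them.
        t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()


rope = SketchRotaryEmbedding(dim=64, max_position_embeddings=128).half()
print(rope.inv_freq.dtype)    # torch.float16
print(rope.cos_cached.dtype)  # still torch.float32
```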

However, this caching behaviour seems inconsistent with the implementation in the EleutherAI/gpt-neox library.
In their implementation, cos_cached and sin_cached are almost always recomputed in the forward method.
Thus, the dtype of cos_cached and sin_cached is always consistent with the dtype of inv_freq.
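For comparison, a hedged sketch of the recompute-in-forward behaviour described above (paraphrased from the idea, not copied from gpt-neox): because the cache is rebuilt from inv_freq inside forward, its dtype follows whatever dtype inv_freq currently has.

```python
import torch


class RecomputingRotaryEmbedding(torch.nn.Module):
    """Sketch of a gpt-neox-style cache that is rebuilt in forward."""

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            # The cache is derived from inv_freq at call time, so after
            # .half() it is computed (and stored) in float16 as well.
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached
```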

This inconsistency between the two libraries (HF Transformers and GPT-NeoX) causes the performance degradation of models converted from gpt-neox.
For example, the perplexity of the language model on the WikiText corpus is as follows:

  • gpt-neox w/o conversion: 520.7840
  • gpt-neox w/ conversion to HF format: 520.9911
  • gpt-neox w/ conversion to HF format and this PR: 520.7840

(Apologies that the perplexity values are so poor; they come from a model trained on toy data for debugging purposes.)

Solution

I basically followed the previous PR #22888 and made a similar fix.

Possible Side Effect

In the original code, cos_cached and sin_cached are initialized in the model constructor.
However, I had to move the initialization code into the forward method.
Otherwise, the library raised the following error: "cos_vml_cpu" not implemented for 'Half'.
As a result, torch.jit.trace might no longer work.
Since I am not sure what jit.trace does, I don't have a workaround for this.
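For context, the error comes from calling cos() on a float16 tensor on CPU, which some PyTorch builds do not implement. Below is a minimal sketch of one possible workaround, computing the trigonometry in float32 and casting afterwards; this is an illustration, not necessarily what this PR implements.

```python
import torch

half_emb = torch.randn(8, 16).half()  # float16 tensor on CPU

# On PyTorch builds without half-precision trig on CPU, this raises:
#   RuntimeError: "cos_vml_cpu" not implemented for 'Half'
try:
    half_emb.cos()
except RuntimeError as err:
    print(err)

# Possible workaround: do the trigonometry in float32, then cast back
# to the model dtype so the cache matches inv_freq.
cos_cached = half_emb.float().cos().to(half_emb.dtype)
print(cos_cached.dtype)  # torch.float16
```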

Similar Issues

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?
    • I would really appreciate it if the reviewers could point out the missing tests.

Who can review?

@ArthurZucker and @younesbelkada

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sgugger (Collaborator) left a comment


Thanks for your PR, but we cannot break the model like this, as it is used by multiple checkpoints on the Hub.

@butsugiri (Author)

Thank you for the message.
While I appreciate that we have to keep compatibility with existing models on the Hub, my understanding is that all existing models converted from NeoX have this precision issue.
I would like to explore alternative solutions to address this issue rather than simply closing the pull request. Is there any other approach we could consider to fix the problem?

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jul 6, 2023
@ArthurZucker (Collaborator)

Hey! If you want to fix the problem without having to close the PR, you should aim for full backward compatibility: add tests to make sure that the issue is fixed in place and that the previous behaviour is not broken.
