Fix floating point precision issue for RoPE #23837
What does this PR do?
This PR fixes the floating point precision issue in `RotaryEmbedding`. The purpose of this PR is to fix an inconsistency between GPT-NeoX and HF Transformers, which causes model performance degradation.
Issue
In the current implementation of `RotaryEmbedding`, `inv_freq` is first initialized in float32. This value is then used to initialize `cos_cached` and `sin_cached`, also in float32. As a result, `cos_cached` and `sin_cached` remain float32 even if the model (including `inv_freq`) is converted to float16, because these two tensors are not targeted by the dtype conversion of the `half()` method. Note that there is also recomputation logic for these two tensors, but it is very unlikely to be triggered:
transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py, line 268 (at f67dac9)
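For illustration, the pattern looks roughly like this (a condensed sketch with simplified names, not the actual library code):

```python
import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    """Condensed illustration of the caching pattern described above."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # The caches are built once here, in float32.
        t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Plain attributes, so Module.half() / Module.to() never converts them.
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]


rope = RotaryEmbeddingSketch(dim=64).half()
print(rope.inv_freq.dtype)    # torch.float16 -- registered buffer, converted
print(rope.cos_cached.dtype)  # torch.float32 -- plain attribute, left untouched
```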
However, this implementation seems inconsistent with the one in the EleutherAI/gpt-neox library. In their implementation, `cos_cached` and `sin_cached` are recomputed in the forward method on almost every call. Thus, the dtype of `cos_cached` and `sin_cached` is always consistent with the dtype of `inv_freq`. This inconsistency between the two libraries (HF Transformers and GPT-NeoX) causes performance degradation for models converted from gpt-neox.
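The recompute-in-forward pattern amounts to roughly the following (my paraphrase of the behavior described above, not EleutherAI's actual code; their caching of the previous result is omitted, and float16 `cos`/`sin` assumes the tensors live on GPU):

```python
import torch

def rotary_tables(inv_freq, seq_len, device):
    # Rebuild the cos/sin tables from inv_freq on every call, so their dtype
    # always tracks the current dtype of inv_freq (e.g. float16 after half()).
    t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()
```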
For example, the perplexity of the language model on the WikiText corpus is as follows:
(Sorry that the perplexity values are really bad; I am reporting the performance of a model trained on toy data for debugging purposes.)
Solution
I basically followed the previous PR #22888 and made a similar fix.
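Concretely, the change goes in the following direction (my own condensed sketch with simplified names, not the exact diff of this PR or of #22888): defer building the caches to the forward method, so they are derived from `inv_freq` after any `half()`/`to()` conversion has already happened.

```python
import torch


class RotaryEmbeddingFixedSketch(torch.nn.Module):
    """Condensed sketch of the proposed behavior, not the merged HF code."""

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # No cache construction here: __init__ runs on CPU, where torch.cos
        # is not implemented for float16 (the "cos_vml_cpu" error below).
        self.cos_cached = None
        self.sin_cached = None
        self.max_seq_len_cached = 0

    def forward(self, x, seq_len):
        if self.cos_cached is None or seq_len > self.max_seq_len_cached:
            # By now inv_freq has been converted by half()/to(), so the
            # caches inherit the model's dtype automatically.
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
            self.max_seq_len_cached = seq_len
        return (self.cos_cached[..., :seq_len, :],
                self.sin_cached[..., :seq_len, :])
```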
Possible Side Effect
In the original code, `cos_cached` and `sin_cached` are initialized in the model constructor. However, I had to move the initialization code to the forward method; otherwise the library gave me the following error: `"cos_vml_cpu" not implemented for 'Half'`.
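For reference, the error comes from PyTorch itself: at the time of writing, `torch.cos` is not implemented for float16 tensors on CPU, which can be reproduced outside the model:

```python
import torch

t = torch.arange(8, dtype=torch.float16)  # float16 tensor on CPU
torch.cos(t)  # RuntimeError: "cos_vml_cpu" not implemented for 'Half'
```

Deferring the cache construction to forward avoids this, because by then the tensors are typically already on GPU, where float16 `cos`/`sin` are supported.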
As a result of this change, `torch.jit.trace` might no longer work. Since I am not sure how `jit.trace` is used with this model, I don't have a workaround for this.
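If it helps, a smoke test along the following lines could check whether tracing still works (a hypothetical check with a tiny random config, not something I have run as part of this PR):

```python
import torch
from transformers import GPTNeoXConfig, GPTNeoXModel

config = GPTNeoXConfig(hidden_size=64, num_attention_heads=4,
                       num_hidden_layers=2, intermediate_size=128,
                       vocab_size=1000, torchscript=True)
model = GPTNeoXModel(config).eval()
dummy_ids = torch.randint(0, config.vocab_size, (1, 16))

# Tracing records the operations executed for this particular input, so if
# the caches are now created inside forward(), the trace may bake in the
# branch/shapes seen here. If this call fails or warns, that is the concern.
traced = torch.jit.trace(model, dummy_ids)
traced(dummy_ids)
```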
Similar Issues
Who can review?
@ArthurZucker and @younesbelkada