
RotaryEmbedding computation is wrong for certain position/feature pairs in reduced precision (both fp16 and bfloat) #1003

Closed
cbcase opened this issue Jul 27, 2023 · 8 comments · Fixed by #1041
Labels: bug (Something isn't working)

Comments

@cbcase

cbcase commented Jul 27, 2023

Describe the bug
The RotaryEmbedding module does substantially all of the computation of the cached cos and sin tables in whatever the model precision is (usually fp16 or bfloat16). For certain (position, feature) pairs, this produces wildly different values than the corresponding fp32 computation.

To Reproduce
Here is a small reproducer:

import torch

seqlen = 2048
dim = 32
base = 10000

# fp32 reference computation
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to("cuda") / dim))
t = torch.arange(seqlen, device="cuda").type_as(inv_freq)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb[1857, 1])
cos_float = emb.cos()

# same computation with everything rounded to fp16
inv_freq = inv_freq.half()
t = torch.arange(seqlen, device="cuda").type_as(inv_freq)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb[1857, 1])
cos_half = emb.cos()

print((cos_half.float() - cos_float).abs().max())

On my machine, I get:

tensor(1044.2678, device='cuda:0')
tensor(1045., device='cuda:0', dtype=torch.float16)
tensor(0.7819, device='cuda:0')

As you can see, the issue is that some of the outer product values (emb) are large, so the small relative rounding error introduced by fp16 becomes large in absolute terms compared to the 2π period of cos/sin. Note that the issue is relatively worse in bfloat16, since it has even less precision (spread across a wider range).
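
(To make the spacing argument concrete, here is a small standalone check, not part of the original repro: fp16 values between 1024 and 2048 are spaced exactly 1.0 apart, which is a sizeable fraction of the 2π period of cos/sin.)

import math
import torch

# adjacent fp16 values in [1024, 2048) are exactly 1.0 apart
print(torch.finfo(torch.float16).eps * 1024)  # 1.0

# the two angles printed above differ by ~0.73 radians, a large fraction
# of the 2*pi (~6.28) period, so cos of the rounded angle moves a lot
a_fp32, a_fp16 = 1044.2678, 1045.0
print(a_fp16 - a_fp32)
print(math.cos(a_fp32), math.cos(a_fp16))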

Expected behavior
The computed embeddings shouldn't depend on the model precision (beyond the final rounding of the computed cos/sin tables).

Proposed solution
There are lots of reasonable ways to rework this. I expect the main pain point is that inv_freq is a buffer stored in people's checkpoints in the model dtype.

cbcase added the bug label on Jul 27, 2023
@zhangir-azerbayev
Contributor

Noting this as a possible cause of DeepSpeed issue #3742

@zhangir-azerbayev
Contributor

Isn't NeoX already doing it the cos_float way for bf16? Note the definition of RotaryEmbedding in NeoX: the computation of self.cos_cached and self.sin_cached is done in float32 when self.precision == torch.bfloat16.

@cbcase
Author

cbcase commented Jul 27, 2023

The cast-to-float happens after emb has already been computed in bfloat (https://github.com/EleutherAI/gpt-neox/blob/math-lm/megatron/model/positional_embeddings.py#L56), and the issue (as in the repro above) is that rounding emb to 16 bits corrupts the computed cos/sin values. The cos/sin functions themselves work fine in any dtype (though I assume there's some jit issue with bfloat, given the code).
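
(For concreteness, a sketch of the ordering being described here, i.e. form the angles in fp32 and only cast the final tables; this is an illustration, not the actual NeoX code.)

import torch

def rope_tables(seqlen, dim, base=10000, dtype=torch.bfloat16, device="cuda"):
    # form the angles entirely in fp32 ...
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    # ... and only round the final cos/sin tables to the model dtype
    return emb.cos().to(dtype), emb.sin().to(dtype)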

@zhangir-azerbayev
Contributor

Doesn't the type_as(self.inv_freq) on line 53 mean that t, and consequently freqs and emb, are in fp32? It looks like the if block starting at line 56 is redundant.

@cbcase
Author

cbcase commented Jul 27, 2023

self.inv_freq will be cast to the model dtype during deepspeed.initialize, so the dtype of inv_freq changes from fp32 to reduced precision between construction and the first call to forward, when the cos/sin tables are computed.
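
(As a PyTorch-only illustration of the same effect, not from this thread: casting a module with .half() also casts its registered floating-point buffers, which is what happens to inv_freq here.)

import torch
import torch.nn as nn

class Rotary(nn.Module):
    def __init__(self, dim=32, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)  # a buffer, so it follows module-wide casts

m = Rotary()
print(m.inv_freq.dtype)  # torch.float32 at construction
m.half()                 # module-wide cast, comparable to what the engine applies
print(m.inv_freq.dtype)  # torch.float16 by the time forward() first runs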

@zhangir-azerbayev
Contributor

zhangir-azerbayev commented Jul 27, 2023

So DeepSpeed casts the module buffer to the model dtype when deepspeed.initialize is called? I see; where is this documented? Would setting inv_freq = self.inv_freq.float() during the forward pass be enough, or do we need to somehow save/reinstantiate self.inv_freq in float32?

@cbcase
Author

cbcase commented Jul 28, 2023

I don't know about documentation, but you can see it in DeepspeedEngine here: https://github.com/microsoft/DeepSpeed/blob/46784cb58edf7bbe9b6bbec95212de7b81e55b01/deepspeed/runtime/engine.py#L1142

The best fix is probably to do all of the cos/sin table calculation and caching in the constructor, where you can be sure about dtypes (or to insert appropriate casts in _prepare_cache).
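
(A minimal sketch of the "insert appropriate casts" option; the class layout and caching details below are assumptions for illustration, not the actual gpt-neox implementation.)

import torch

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.precision = precision
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _prepare_cache(self, seq_len, device):
        if seq_len != self.seq_len_cached:
            # cast back to fp32 here, in case the buffer was downcast after construction
            inv_freq = self.inv_freq.float().to(device)
            t = torch.arange(seq_len, device=device, dtype=torch.float32)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # only the final tables are rounded to the model precision
            self.cos_cached = emb.cos().to(self.precision)
            self.sin_cached = emb.sin().to(self.precision)
            self.seq_len_cached = seq_len
        return self.cos_cached, self.sin_cached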

@zhangir-azerbayev
Contributor

zhangir-azerbayev commented Aug 13, 2023

@cbcase How does this implementation look? I admit it's quite hacky, because I wanted to preserve backward compatibility of checkpoints. Also, right now when seq_len changes it computes a new cos/sin table instead of slicing the cached one. If this looks good to you, I can turn it into a version that's ready for main. https://github.com/EleutherAI/gpt-neox/blob/math-lm-2-rotary/megatron/model/positional_embeddings.py#L38
