Add GraniteRMSNorm #33177
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi, this makes sense.
thanks
Should I run the slow tests, or can this be merged as-is?
Might be better to run the slow tests on the granite class @NielsRogge
Seems like the slow tests are failing (cc @ydshieh), but I assume it's safe to merge this PR since the following passes:

```python
import torch
from torch import nn


class GraniteRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        GraniteRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


a = torch.nn.RMSNorm(10)
b = GraniteRMSNorm(10)
assert a.weight.shape == b.weight.shape

c = torch.randn(1, 10)
assert torch.allclose(a(c), b(c))
```
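One design point worth noting as a side observation (not something this PR changes): the forward pass upcasts to float32 before computing the variance, so reduced-precision inputs keep their dtype at the output while the statistics are computed at full precision. A minimal sketch, assuming the GraniteRMSNorm class above is in scope:

```python
# Minimal sketch (illustration, not part of the PR): the fp32 upcast in
# forward() computes the RMS statistics at full precision, while the
# output is cast back to the input dtype.
import torch

norm = GraniteRMSNorm(10).to(torch.bfloat16)  # class from the snippet above
x = torch.randn(1, 10, dtype=torch.bfloat16)

out = norm(x)
assert out.dtype == torch.bfloat16  # dtype is preserved end to end

# The result matches a full-fp32 computation up to one bf16 rounding step.
ref = norm(x.float()).to(torch.bfloat16)
assert torch.allclose(out, ref, atol=3e-2)  # loose tolerance for bf16 rounding
```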
When will this be released to fix Transformers?
@ArthurZucker can you merge this?
* first commit
* drop tokenizer
* drop tokenizer
* drop tokenizer
* drop convert
* granite
* drop tokenization test
* mup
* fix
* reformat
* reformat
* reformat
* fix docs
* stop checking for checkpoint
* update support
* attention multiplier
* update model
* tiny drop
* saibo drop
* skip test
* fix test
* fix test
* drop
* drop useless imports
* update docs
* drop flash function
* copied from
* drop pretraining tp
* drop pretraining tp
* drop pretraining tp
* drop unused import
* drop code path
* change name
* softmax scale
* head dim
* drop legacy cache
* rename params
* cleanup
* fix copies
* comments
* add back legacy cache
* multipliers
* multipliers
* multipliers
* text fix
* fix copies
* merge
* multipliers
* attention multiplier
* drop unused imports
* fix
* fix
* fix
* move rope?
* Update src/transformers/models/granite/configuration_granite.py (Co-authored-by: Arthur <[email protected]>)
* fix
* Update src/transformers/models/granite/modeling_granite.py (Co-authored-by: Arthur <[email protected]>)
* fix
* fix
* fix
* fix
* fix-copies
* torch rmsnorm
* add authors
* change model path
* fix
* test
* drop static cache test
* uupdate readme
* drop non-causal
* readme
* drop useless imports
* Update docs/source/en/model_doc/granite.md (Co-authored-by: Arthur <[email protected]>)
* Update docs/source/en/model_doc/granite.md (Co-authored-by: Arthur <[email protected]>)
* Update docs/source/en/model_doc/granite.md (Co-authored-by: Arthur <[email protected]>)

---------

Co-authored-by: Arthur <[email protected]>
Okay
cc @ydshieh if you see failures on this, it's expected! |
Thanks @ArthurZucker, I'll fix this test in a new PR
* Add GraniteRMSNorm
* [run_slow] granite
What does this PR do?
This PR is a follow-up to #31502, which broke Transformers for PyTorch < 2.4.
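For context, torch.nn.RMSNorm only exists since PyTorch 2.4, so referencing it unconditionally breaks older installs. A minimal sketch of the kind of guard that avoids this (an illustration, not the exact patch in this PR):

```python
# Minimal sketch (illustration only, not the exact patch in this PR):
# torch.nn.RMSNorm was added in PyTorch 2.4, so guard the reference and
# fall back to a hand-written equivalent on older versions.
from torch import nn

if hasattr(nn, "RMSNorm"):
    norm_cls = nn.RMSNorm        # available on PyTorch >= 2.4
else:
    norm_cls = GraniteRMSNorm    # manual implementation, as defined above

norm = norm_cls(10)
```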