Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GraniteRMSNorm #33177

Merged
merged 2 commits into from
Sep 2, 2024
Merged

Add GraniteRMSNorm #33177

merged 2 commits into from
Sep 2, 2024

Conversation

NielsRogge
Copy link
Contributor

What does this PR do?

This PR is a follow-up of #31502 which broke Transformers for PyTorch < 2.4.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@mayank31398
Copy link
Contributor

Hi this makes sense.
I had added the nn.RMSNorm class
but looks like there is some issue with fp16 as well: pytorch/pytorch#134106

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

@NielsRogge
Copy link
Contributor Author

NielsRogge commented Aug 29, 2024

Should I run slow tests or can this be merged as-is?

@mayank31398
Copy link
Contributor

might be better to run slow tests on the granite class @NielsRogge

@NielsRogge
Copy link
Contributor Author

Seems like the slow tests are failing (cc @ydshieh), but I assume it's safe to merge this PR since the following passes

from torch import nn
import torch


class GraniteRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        GraniteRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
    

a = torch.nn.RMSNorm(10)

b = GraniteRMSNorm(10)

assert a.weight.shape == b.weight.shape

c = torch.randn(1, 10)

assert torch.allclose(a(c), b(c))

@iamsaurabhgupt
Copy link

when will this release to fix transformers?
creating lot of conflicts as pytorch==2.4 is not supported by many packages like vllm, flash_attn, etc.

@mayank31398
Copy link
Contributor

@ArthurZucker can you merge this?
I will create a new PR for fixing the slow tests.
Looks like its blocking the DeepSpeed team.

ArthurZucker referenced this pull request Sep 2, 2024
* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* fix

* fix

* fix

* move rope?

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <[email protected]>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <[email protected]>

* fix

* fix

* fix

* fix

* fix-copies

* torch rmsnorm

* add authors

* change model path

* fix

* test

* drop static cache test

* uupdate readme

* drop non-causal

* readme

* drop useless imports

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/model_doc/granite.md

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
@ArthurZucker
Copy link
Collaborator

Okay

@ArthurZucker ArthurZucker merged commit b9bc691 into huggingface:main Sep 2, 2024
22 of 24 checks passed
@ArthurZucker
Copy link
Collaborator

cc @ydshieh if you see failures on this, it's expected!

@mayank31398
Copy link
Contributor

Thanks @ArthurZucker , Ill fix this test in a new PR

itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* Add GraniteRMSNorm

* [run_slow] granite
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants