
RotaryEmbedding applied to the incorrect channel dimension #841

sagadre commented Aug 31, 2023

🐛 Bug

Input tensors to attention must be in the format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head, as documented here.

Hence the positional embedding (e.g., rotary embedding) should be applied along dim=1, the sequence dimension. However, the RotaryEmbedding class passes seq_dimension=-2, which for a 4-D input corresponds to dim=2, i.e., the heads dimension, as seen here. A short sketch after the snippet illustrates the resulting shape mixup.

def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        k, seq_dimension=-2  # should be seq_dimension=1, or no argument, since the default value is correct
    )

    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )
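For illustration, below is a minimal, self-contained sketch in plain PyTorch (not the xformers code; the rotary_tables helper and the tensor sizes are made up) of why the dimension matters: with the documented [B, M, H, K] layout, seq_dimension=-2 indexes the heads axis, so the cached cos/sin tables end up sized to the number of heads rather than to the sequence length.

import torch

# Made-up sizes: batch, sequence length, heads, per-head embedding dim
B, M, H, K = 2, 16, 4, 8
k = torch.randn(B, M, H, K)  # keys in the documented [B, M, H, K] layout

def rotary_tables(t: torch.Tensor, seq_dimension: int):
    # Sketch of how rotary cos/sin caches are commonly built: one row per
    # position along `seq_dimension`, each row of width K (the per-head dim).
    seq_len = t.shape[seq_dimension]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, K, 2).float() / K))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, K/2]
    emb = torch.cat((freqs, freqs), dim=-1)                       # [seq_len, K]
    return emb.cos(), emb.sin()

cos_bad, _ = rotary_tables(k, seq_dimension=-2)  # built over H -> only 4 "positions"
cos_good, _ = rotary_tables(k, seq_dimension=1)  # built over M -> 16 positions, as intended
print(cos_bad.shape, cos_good.shape)             # torch.Size([4, 8]) torch.Size([16, 8])

With seq_dimension=1 (which, as noted in the snippet above, is the default value) the cache covers every position in the sequence, as intended.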

Additional context

Thanks to @jmercat, who found symptoms of this problem downstream of xformers!
