4D attention_mask support #27539

Merged · 14 commits · Dec 17, 2023

Conversation

poedator
Contributor

@poedator commented Nov 16, 2023

This is an implementation of the feature request from #27493: support for a custom 4D attention_mask as a transformers .forward() argument.

  1. Allow 4D attention masks to pass through _prepare_4d_causal_attention_mask() intact
  2. Support in OPT (a custom positions tensor needs to be built)
  3. Support in Llama (Llama can already accept custom position_ids, but I added code to generate them internally)

The benefit of the code is to enable more memory-efficient text generation with tree-based parallel decoding, as described in the SpecInfer paper.

Tagging:
@gante (generate)
@patrickvonplaten (masks)
@younesbelkada @ArthurZucker (text models)

This PR is a WIP:

  • Will add tests
  • Need advice on how to handle models beyond the covered Llama and OPT
  • May add an example of memory-efficient generation

IMPORTANT: this PR makes changes that can only be used by a few classes of models.
Requirements to use it:

  • have a position_ids argument in the .forward() method
  • use the modeling_attn_mask_utils.py::_prepare_4d_attention_mask() function for 4D mask generation

As of 20.12.2023, only a handful (under 20) of transformers model classes meet these criteria. Most of these classes are multimodal, which may require their own use cases for 4D masks. The only pure language-modelling classes fit to use the 4D mask changes from this PR are LlamaModel, FalconModel and XGLMModel.
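
For illustration, here is a minimal sketch of calling a compatible model with a hand-built 4D mask (my own example, not code from the PR; the small JackFram/llama-160m checkpoint is the one later used in the tests). The mask is just an ordinary causal mask written out explicitly in the (batch_size, 1, query_length, kv_length) layout, with 1 = attend and 0 = do not attend:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch, assuming a small Llama-like checkpoint; any model meeting the two
# requirements above should behave the same way.
model_name = "JackFram/llama-160m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("11 22 33", return_tensors="pt").input_ids
seq_len = input_ids.shape[1]

# Hand-built causal mask with explicit shape (batch_size, 1, query_length, kv_length).
mask_4d = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))[None, None, :, :]
position_ids = torch.arange(seq_len)[None, :]

with torch.no_grad():
    out = model(input_ids, attention_mask=mask_4d, position_ids=position_ids)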

@poedator marked this pull request as ready for review on November 16, 2023 at 19:12
@patrickvonplaten
Contributor

Generally, I don't have a problem with allowing 4D attention masks to be passed! @poedator, can you explain your use case a little bit? Why do you want to pass 4D attention masks?

@poedator
Contributor Author

poedator commented Nov 20, 2023

@patrickvonplaten
here is a use example:
Suppose one is doing beam search with 4 beams and a shared starting prefix of tokens 11 22 33, and now needs to check candidates with tokens 44, 55, 66, and 77. The present code would pack the beams into a batch of shape (4, 4):

11 22 33  44
11 22 33  55
11 22 33  66
11 22 33  77

and run it with a mask of all ones, passed in 2D and expanded internally to 4D.

The proposed way would be to use a batch of shape (1, 7):
11 22 33 44 55 66 77
and a 4D mask of shape (1, 1, 4, 7) that looks like this:

1  1  1  1  0  0  0 
1  1  1  0  1  0  0 
1  1  1  0  0  1  0 
1  1  1  0  0  0  1

with a positions tensor of [0, 1, 2, 3, 3, 3, 3]

At subsequent beam search iterations the mask reflects which past tokens the new tokens should attend to.
Such a mask needs to be passed through intact.
This saves memory in the past_key_values cache and thus allows beam search and similar inference schemes (like SpecInfer) to handle longer sequences with limited VRAM.
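
For concreteness, here is a small sketch (illustrative only, not code from this PR) that builds the (1, 1, 4, 7) mask and the positions tensor described above:

import torch

# Shared prefix of 3 tokens (11 22 33) plus 4 candidate continuations (44 55 66 77),
# flattened into a single sequence of 7 tokens.
prefix_len, n_candidates = 3, 4
kv_len = prefix_len + n_candidates

mask_4d = torch.zeros(1, 1, n_candidates, kv_len, dtype=torch.long)
mask_4d[..., :prefix_len] = 1              # every candidate attends to the shared prefix
idx = torch.arange(n_candidates)
mask_4d[0, 0, idx, prefix_len + idx] = 1   # ...and to itself, but not to its siblings

position_ids = torch.tensor([[0, 1, 2, 3, 3, 3, 3]])
print(mask_4d[0, 0])                       # matches the 4x7 pattern shown above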

Another use case is kindly proposed by @UniverseFly below.

@UniverseFly

Very interesting PR! Would this feature also enable SFT packing as mentioned in huggingface/trl#805?

@poedator
Contributor Author

Very interesting PR! Would this feature also enable SFT packing as mentioned in huggingface/trl#805?
Sure it would. One would just need a separate packing function somewhere; that is beyond the scope of this PR.
Besides, one should be able to pack multiple series of sequences into a batch this way; see the sketch below.
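
As a rough sketch of the packing idea (my own illustration, not from this PR or TRL): a block-diagonal causal 4D mask keeps the packed sequences from attending to each other, while position_ids restart at 0 for each sequence.

import torch

# Two packed sequences of lengths 3 and 2 in a single row of the batch.
lengths = [3, 2]
total_len = sum(lengths)

mask_4d = torch.zeros(1, 1, total_len, total_len, dtype=torch.long)
start = 0
for n in lengths:
    # causal (lower-triangular) block for each packed sequence
    mask_4d[0, 0, start:start + n, start:start + n] = torch.tril(torch.ones(n, n, dtype=torch.long))
    start += n

# positions restart at 0 for every packed sequence: [[0, 1, 2, 0, 1]]
position_ids = torch.cat([torch.arange(n) for n in lengths])[None, :]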

@UniverseFly

I tried this branch and model.forward seems to work fairly well, but model.generate raises errors with the 4D attention mask (with Llama). After some checking, it might be due to the missing logic here:

def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[2]

        # Some generation methods already pass only the last input ID
        if input_ids.shape[1] > past_length:
            remove_prefix_length = past_length
        else:
            # Default to old behavior: keep only final ID
            remove_prefix_length = input_ids.shape[1] - 1

        input_ids = input_ids[:, remove_prefix_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs

@poedator
Contributor Author

Generate looks like a harder challenge for your use case: each individual sequence keeps expanding, so you would need to reorder past_key_values and the mask at each step. I believe that to implement it you would need to write a custom prepare_inputs_for_generation(), and possibly some more logic.
I'll be happy to test drive it.
On my side, I intend to write a PR for more efficient beam search after this PR merges.

Contributor

@patrickvonplaten left a comment

Generally the PR looks good to me! (We'd need some tests here).

@ArthurZucker wdyt?

Collaborator

@ArthurZucker left a comment

Looks alright! But there should not be changes to the forward of the models (IMO).

Comment on lines 872 to 877
if attention_mask is not None and len(attention_mask.shape) == 4:
# assumes 4D mask for efficient beam search
token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
used_tokens_mask = attention_mask.amax(dim=(1, 2))
position_ids = (token_positions * used_tokens_mask).long() - 1
position_ids = position_ids[:, past_key_values_length:]
Collaborator

This logic should not go here; it should go in prepare_inputs_for_generation, as it is purely specific to 4D beam search.

Contributor Author

I agree that this should be limited to just the mask code; it makes this PR more manageable. Llama can work without that, since it accepts a position_ids argument. Hopefully newer models will support this argument too (could HF make it part of some model guidelines?).
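
As a quick illustrative check (not part of the PR), the position-id derivation from the snippet above reproduces the [0, 1, 2, 3, 3, 3, 3] positions tensor of the earlier beam-search example when applied to its (1, 1, 4, 7) mask:

import torch

mask_4d = torch.tensor([[[[1, 1, 1, 1, 0, 0, 0],
                          [1, 1, 1, 0, 1, 0, 0],
                          [1, 1, 1, 0, 0, 1, 0],
                          [1, 1, 1, 0, 0, 0, 1]]]])

token_positions = torch.cumsum(mask_4d, dim=-1).amax(dim=(1, 2))
used_tokens_mask = mask_4d.amax(dim=(1, 2))
position_ids = (token_positions * used_tokens_mask).long() - 1
print(position_ids)  # tensor([[0, 1, 2, 3, 3, 3, 3]])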

@poedator
Contributor Author

Hi @ArthurZucker,
I limited this PR to just the mask code and am proceeding with the tests.

So far I have a demo in Colab with a monkey patch based on this PR. It shows a negligible difference between the logits obtained the old and new ways. I tend to believe that this is a rounding error somewhere. Would you support it as the basis for the tests?
BTW, where should this new test go?

Hi @UniverseFly,
try the monkey patch from the Colab notebook and see whether it works for implementing your idea.

@KexinFeng

Thanks for this PR and the demo. It is very helpful for trying out the SpecInfer paper. This PR will also be useful for another recent advance in speculative decoding, lookahead decoding (Fig. 5).

@ArthurZucker
Collaborator

Reviewing now 😉

Collaborator

@ArthurZucker left a comment

LGTM, the test should go in the

class AttentionMaskTester(unittest.TestCase):

@poedator
Contributor Author

poedator commented Dec 10, 2023

  • squashed all earlier commits into one
  • added tests. Made a separate class to test with full model loading.
  • added support for sdpa (following F.scaled_dot_product_attention support #26572)
  • test_modeling_utils.py::AttentionMaskTester and ::TestAttentionImplementation tests pass
  • new tests pass

@ArthurZucker, please review. Hopefully it is ready to merge.

Collaborator

@ArthurZucker left a comment

Thanks, just a few testing nits and good to go

tests/test_modeling_utils.py
self.device = torch.device("cuda:0")
model_name = "JackFram/llama-160m" # small Llama-like model from FlexFlow
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device)
Collaborator

Suggested change
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)

the smaller the better for our CI

Contributor Author

I observed that fp16 tests are noisier, so what I did is:

  • retained the fp32 tests but used an even smaller model
  • added an fp16 test with relaxed tolerances
  • added an fp16 testing option for the top-token order

@poedator
Contributor Author

@ArthurZucker, please give me a hint about the NameError: name 'torch' is not defined error. Apparently a decorator or import is missing, but I can't figure it out. The import and decorators seem to be in place...

@poedator
Contributor Author

poedator commented Dec 16, 2023

Sorry checking the test there are duplicate markers 😅 not sure they are needed no?

Earlier, I got frustrated with failing commits and added decorators everywhere. Now most of them are gone and it still passes CI checks.

@ArthurZucker merged commit f85a1e8 into huggingface:main on Dec 17, 2023
18 checks passed
@ArthurZucker
Collaborator

Thanks for the contribution! 🤗

@poedator
Contributor Author

@ArthurZucker, would you want to publish a post on the HF blog with 4D attention use cases?
I propose to include:

@ArthurZucker
Collaborator

If you want, feel free to do so! 🤗

@PhilJd

PhilJd commented Dec 19, 2023

Note that not all paths of this can be torch.compiled.

The following fails due to torch.all(attention_mask == 1):

import torch
import torch.nn as nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

class Model(nn.Module):
    def forward(self, inputs_embeds):
        batch_size, seq_length, _ = inputs_embeds.shape
        past_key_values_length = 10
        attention_mask = torch.tensor([1.])
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
        return attention_mask

model = Model()
model = torch.compile(model, fullgraph=True)
model(torch.ones([1, 5, 32]))

@poedator
Contributor Author

@PhilJd,
torch.all(attention_mask == 1) was present even before this PR,
see this line;
it comes from #26572.

Have you tested the preceding commit?

@PhilJd

PhilJd commented Dec 19, 2023

Ah sorry, I just looked at the blame: yeah, the previous commit fails as well, @fxmarty.

@shentianxiao

_prepare_4d_causal_attention_mask is applied only if self._use_flash_attention_2 is False (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1039). Is it because 4D attention mask and flash attention 2 are not compatible?

@shentianxiao

The function description should be updated to avoid confusion as attention_mask is not necessarily 2D now (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py#L290)

@poedator
Contributor Author

poedator commented Dec 19, 2023

@shentianxiao, thank you for your attention to the 4D attention!

_prepare_4d_causal_attention_mask is applied only if self._use_flash_attention_2 is False (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1039). Is it because 4D attention mask and flash attention 2 are not compatible?

It is not about compatibility; rather, the flash_attention_2 code uses the original mask as opposed to the modified mask coming from _prepare_4d_causal_attention_mask().

The function description should be updated to avoid confusion as attention_mask is not necessarily 2D now (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py#L290)

I agree that the original mask may also be 4D-shaped now. I just started PR #28151 with documentation updates and will make the edits there. Hopefully the maintainers responsible for flash_attention_2 will verify it.

@poedator
Contributor Author

IMPORTANT: this PR makes changes that can only be used by a few classes of models.
Requirements to use it:

  • have a position_ids argument in the .forward() method
  • use the modeling_attn_mask_utils.py::_prepare_4d_attention_mask() function for 4D mask generation

As of 20.12.2023, only a handful (under 20) of transformers model classes meet these criteria. Most of these classes are multimodal, which may require their own use cases for 4D masks. The only pure language-modelling classes fit to use the 4D mask changes from this PR are LlamaModel, FalconModel and XGLMModel.

@poedator
Contributor Author

I made a small blog post based on this PR: https://huggingface.co/blog/poedator/4d-masks
Big thanks to everyone who contributed and commented!

staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
* edits to _prepare_4d_causal_attention_mask()

* initial tests for 4d mask

* attention_mask_for_sdpa support

* added test for inner model hidden

* added autotest decorators

* test mask dtype to torch.int64

* torch.testing.assert_close

Co-authored-by: Arthur <[email protected]>

* torch_device and @torch_gpu in tests

* upd tests

* +torch decorators

* torch decorators fixed

* more decorators!

* even more decorators

* fewer decorators

---------

Co-authored-by: Arthur <[email protected]>
@jpgard

jpgard commented Feb 18, 2024

Thanks for the amazing addition!! This is a great new feature.

Just wanted to ask a question to make sure I am using it properly. In the code here, it looks like the 4D masks are expected to have shape [batch_size, 1, seq_len, seq_len]. (I am inferring that the 1 in the expected_shape is the heads dimension, so that the same mask is broadcast to all heads.) The blog post, however, describes the attention masks as having shape [heads, batch_size, input_ids_length, total_sequence_length].

My question is: are the heads and batch_size dimensions transposed in the blog post? It seems like we are actually expected to provide 4D masks where the first axis is batch size, the second is heads. The blog post implies the reverse. Since I am sometimes using a batch size of 1 in testing, this works either way, but I want to use it correctly and don't see the "proper" shape documented anywhere (perhaps it is documented somewhere and I missed it!).

Thanks!

@poedator
Contributor Author

@jpgard,
you are correct, there was an error in my blog post.
I changed it to [batch_size, heads, input_ids_length, total_sequence_length].
Thank you for raising this!

@jpgard

jpgard commented Feb 19, 2024

Great, thanks for the quick reply and for your hard work on this @poedator !!

@jpgard

jpgard commented Feb 19, 2024

Has this been tested with flash attention 2? It works great for me without flash attention 2, but when using flash attention I get lots of messages of the form ../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [202,0,0], thread: [105,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.

The lower chunk of the stack trace is posted below.

 File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 549, in forward
    attn_output = self._flash_attention_forward(
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 592, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 631, in _upad_input
    query_layer = index_first_axis(
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/admin/home-jpgard/miniconda3/envs/rtfm/lib/python3.8/site-packages/flash_attn/bert_padding.py", line 17, in forward
    return torch.gather(
RuntimeError: CUDA error: device-side assert triggered

It would be great to be able to use FA2 with this PR, since the speedups grow with sequence length; FA2 seems like the perfect accompaniment to e.g. the "packed" training sequences enabled by this PR.

@poedator
Contributor Author

@jpgard, please share some simple testing code. I will look into this issue.
