add flash_attn_kvpacked #862

Merged (14 commits) on Apr 21, 2023

Conversation

satpalsr
Contributor

@satpalsr satpalsr commented Mar 29, 2023

Issue
Inference with flash attention gives error due to improper shapes of qkv, as key and value are updated due to layer past
To Reproduce:

  1. Take any config with flash attention.
  2. Run python deepy.py generate.py configs/70M-deduped.yml -i input_prompt.txt -o prompt_out.txt with some text in input_prompt.txt
  3. Gives error due to different sq and sk values here

This PR separates query from packed qkv matrix to resolve the issue.
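
A minimal sketch of that idea (not the exact PR diff): keep the query separate and pack only key and value for the kv-packed kernel, so the cached key/value from layer_past can make sk larger than sq. It assumes the flash-attn 1.x interface flash_attn_unpadded_kvpacked_func and [b, s, np, hn]-shaped inputs; the function name flash_attention_with_cache is illustrative.

    # Sketch only, not gpt-neox code. Assumes flash-attn 1.x.
    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

    def flash_attention_with_cache(query_layer, key_layer, value_layer, causal=True):
        # query_layer: [b, sq, np, hn]; key_layer / value_layer: [b, sk, np, hn].
        # During incremental decoding layer_past extends k/v, so sk > sq and a
        # qkv-packed kernel (which requires sq == sk) can no longer be used.
        b, sq, np_, hn = query_layer.shape
        sk = key_layer.shape[1]

        # Query stays separate: [b * sq, np, hn].
        q = query_layer.reshape(b * sq, np_, hn)
        # Pack only key and value: [b * sk, 2, np, hn].
        kv = torch.stack([key_layer, value_layer], dim=2).reshape(b * sk, 2, np_, hn)

        # Cumulative sequence lengths for the un-padded (varlen) interface.
        cu_seqlens_q = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.arange(0, (b + 1) * sk, sk, dtype=torch.int32, device=q.device)

        out = flash_attn_unpadded_kvpacked_func(
            q, kv, cu_seqlens_q, cu_seqlens_k, sq, sk, 0.0, causal=causal
        )
        return out.reshape(b, sq, np_, hn)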

@satpalsr satpalsr requested a review from a team as a code owner March 29, 2023 16:40
@dashstander dashstander self-assigned this Apr 13, 2023
@dashstander
Contributor

dashstander commented Apr 13, 2023

Thanks so much for the PR! This looks great and I'm testing it today. @satpalsr, can you merge main into your branch and resolve the conflicts against the recent changes that add Triton Flash Attention support?

@DaoD

DaoD commented Apr 14, 2023

I find this PR relevant to my issue (#883); the general idea for tackling the problem is the same.
I think we do not need to add a new layer_past parameter to the flash_attention function, but can simply check self.training to tell whether we are in the training or inference stage (a sketch of this idea follows below).
High five for the same solution!
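
A rough sketch of that suggestion (the helper names _train_path and _inference_path are hypothetical, not functions in the code base): branch on torch.nn.Module's built-in training flag instead of threading layer_past into flash_attention.

    # Hypothetical sketch of the suggestion above.
    def flash_attention(self, query_layer, key_layer, value_layer):
        if self.training:
            # Training: q, k, and v share the same sequence length, so a
            # qkv-packed kernel is fine.
            return self._train_path(query_layer, key_layer, value_layer)
        # Inference: k/v may include cached past tokens, so use a path that
        # allows sk != sq (e.g. the kv-packed kernel).
        return self._inference_path(query_layer, key_layer, value_layer)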

@satpalsr
Contributor Author

@dashstander I've resolved the conflicts. Please test it out now.
Thanks for the suggestion, @DaoD.

Inline review on the changed code:

    # Combined k/v into [b * sk, 2, np, hn].
    kv = torch.concat([key_layer, value_layer], dim=1)

    output = self.flash_attn_unpadded_kvpacked_func(

Contributor:
this needs to be self.flash_attention_function

Contributor Author:
Will run the inference and push any changes as needed.

@dashstander
Contributor

There are a few outstanding bugs that I have pointed out, though the general gist is ensuring that this works for both training and inference and that the code switches between those modes correctly. Some of that is actually not an issue with this PR and seems to be an issue with the generate.py script, as was mentioned in the (now closed) #883. I'm actively working on this now and should be able to iron this stuff out.

@dashstander
Contributor

Just had a productive conversation with @satpalsr. Currently the implementation relies on the torch.nn.Module attribute training, which isn't used elsewhere in the code base. There's also an issue with the three different FlashAttention functions getting assigned to the same name at initialization: since models are always initialized with self.training == True, even resetting it later by calling inference_mode won't change the FlashAttention function.
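
A sketch of that pitfall (class and method names are illustrative, not the PR code): if the kernel is chosen once in __init__ based on self.training, the choice is frozen, because self.training is always True at construction time and flipping it later never revisits the assignment.

    import torch

    class FlashSelfAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # self.training is True at construction time, so this always picks
            # the training kernel; calling .eval() later flips self.training
            # but never re-runs this assignment.
            self.flash_attention_function = (
                self._train_kernel if self.training else self._inference_kernel
            )

        def _train_kernel(self, *args):
            ...

        def _inference_kernel(self, *args):
            ...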

@StellaAthena
Member

> Since models are always initialized with self.training == True even resetting it by calling inference_mode won't change the FlashAttention function.

Where does this happen? I was expecting it to be in setup_for_inference_or_eval but it doesn't seem to be.

@StellaAthena StellaAthena linked an issue Apr 18, 2023 that may be closed by this pull request
@dashstander
Contributor

dashstander commented Apr 18, 2023

> Where does this happen? I was expecting it to be in setup_for_inference_or_eval but it doesn't seem to be.

@StellaAthena It's an attribute built into torch.nn.Module and set to True by default. Part of the changes, if we want to use this particular attribute, will be setting it in the train_mode and inference_mode methods (here and here); a sketch follows below. An alternative might be to add an explicit inference-vs-training configuration to NeoxArgs, but that would definitely need to be part of a larger discussion.
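
A sketch of what that could look like, assuming train_mode and inference_mode are methods on the model wrapper (the method names come from the comment above; the bodies are illustrative): keep torch.nn.Module's training flag in sync when switching modes, so any self.training check inside the attention layer sees the right value.

    def train_mode(self):
        # ... existing train-mode setup ...
        self.train()   # sets self.training = True on this module and all children

    def inference_mode(self):
        # ... existing inference-mode setup ...
        self.eval()    # sets self.training = False on this module and all children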

@dashstander
Contributor

dashstander commented Apr 18, 2023

I also just realized that this doesn't add the Triton kvpacked function, so technically inference with Triton FlashAttention + ALiBi would be broken.

Edit: Nevermind, this is wrong. The Triton function we import has the Q / K / V matrices split up.

@dashstander
Contributor

I tested and confirmed this all worked with these changes.

@dashstander dashstander left a comment
Contributor

@satpalsr added me as a collaborator on their fork so I made the changes I was requesting. Looks good to me from here.

Review threads on megatron/model/transformer.py (outdated, resolved)
@StellaAthena StellaAthena left a comment
Member

Dash has tested and approved this code and I don't see anything that stands out as problematic. You should be good to go ahead and merge it.

@StellaAthena StellaAthena merged commit c64bacc into EleutherAI:main Apr 21, 2023
@satpalsr
Contributor Author

Thanks @dashstander for completing this.

Successfully merging this pull request may close these issues: Unable to run generate text.

5 participants