add flash_attn_kvpacked #862
Conversation
Thanks so much for the PR! This looks great and I'm testing it today. Though @satpalsr can you merge main into your branch and resolve the conflicts against the recent changes that add Flash Attention Triton support?
Signed-off-by: Dashiell Stander <[email protected]>
I find this commit relevant to my issue (#883).
@dashstander I've removed the conflicts. Please test it out now.
Signed-off-by: Dashiell Stander <[email protected]>
megatron/model/transformer.py (Outdated)

# Combined k/v into [b * sk, 2, np, hn].
kv = torch.concat([key_layer, value_layer], dim=1)

output = self.flash_attn_unpadded_kvpacked_func(
this needs to be self.flash_attention_function
Will run the inference and push any changes as needed.
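For reference, a minimal sketch of the kv-packed call pattern the diff above uses, assuming the flash-attn v1 unpadded interface (flash_attn_unpadded_kvpacked_func, which was renamed in later releases); the tensor names and shapes mirror the comment in the snippet, and all sizes are illustrative.

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

b, sq, sk, np_, hn = 2, 16, 16, 8, 64   # batch, q length, k length, heads, head dim

# Unpadded layout: queries flattened to [b * sq, np, hn]; key/value each kept
# as [b * sk, 1, np, hn] so that concatenating along dim 1 packs them into a
# kv tensor of shape [b * sk, 2, np, hn], matching the comment in the diff.
query_layer = torch.randn(b * sq, np_, hn, device="cuda", dtype=torch.float16)
key_layer = torch.randn(b * sk, 1, np_, hn, device="cuda", dtype=torch.float16)
value_layer = torch.randn(b * sk, 1, np_, hn, device="cuda", dtype=torch.float16)
kv = torch.concat([key_layer, value_layer], dim=1)

# Cumulative sequence lengths mark where each sequence starts in the flattened batch.
cu_seqlens_q = torch.arange(0, (b + 1) * sq, sq, dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.arange(0, (b + 1) * sk, sk, dtype=torch.int32, device="cuda")

output = flash_attn_unpadded_kvpacked_func(
    query_layer, kv, cu_seqlens_q, cu_seqlens_k, sq, sk,
    0.0,                 # dropout_p
    softmax_scale=None,
    causal=True,
)                        # output: [b * sq, np, hn]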
Signed-off-by: Dashiell Stander <[email protected]>
Signed-off-by: Dashiell Stander <[email protected]>
There are a few outstanding bugs that I have pointed out, though the general gist is ensuring that this works for both training and inference and that the code switches between those modes correctly. Some of that is actually not an issue with this PR and seems to be an issue with the
Signed-off-by: Dashiell Stander <[email protected]>
Just had a productive conversation with @satpalsr. Currently the implementation relies on the
Where does this happen? I was expecting it to be in
@StellaAthena It's an attribute built into
Edit: Nevermind, this is wrong. The Triton function we import has the Q / K / V matrices split up.
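To make the packed-vs-split distinction concrete, here is a small pure-PyTorch sketch of pulling separate Q / K / V tensors out of a fused projection output, which is the form a kernel taking q, k, v separately consumes. The [sq, b, np, 3 * hn] layout is illustrative only and not necessarily the exact layout NeoX uses.

import torch

sq, b, np_, hn = 16, 2, 8, 64            # seq len, batch, heads, head dim
mixed_x_layer = torch.randn(sq, b, np_, 3 * hn)   # fused QKV projection output

# Split the last dimension into three equal chunks: Q, K, V, each [sq, b, np, hn].
query_layer, key_layer, value_layer = torch.split(mixed_x_layer, hn, dim=-1)
assert query_layer.shape == (sq, b, np_, hn)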
Signed-off-by: Dashiell Stander <[email protected]>
Signed-off-by: Dashiell Stander <[email protected]>
I tested and confirmed this all worked with these changes.
Signed-off-by: Dashiell Stander <[email protected]>
Signed-off-by: Dashiell Stander <[email protected]>
@satpalsr added me as a collaborator on their fork, so I made the changes I was requesting. Looks good to me from here.
Dash has tested and approved this code and I don't see anything that stands out as problematic. You should be good to go ahead and merge it.
Thanks @dashstander for completing this.
Issue
Inference with flash attention gives an error due to improper qkv shapes, since the key and value are updated from the layer past.
To Reproduce:
python deepy.py generate.py configs/70M-deduped.yml -i input_prompt.txt -o prompt_out.txt
with some text in input_prompt.txt.
This PR separates the query from the packed qkv matrix to resolve the issue.
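A hedged sketch of the fix described above: keep the query separate from the key/value so that the cached key/value from the layer past can grow during generation without breaking the shapes flash attention expects. The name layer_past and the [s, b, np, hn] layout are illustrative, not the exact NeoX code.

import torch

def update_kv_with_layer_past(query_layer, key_layer, value_layer, layer_past=None):
    # query_layer: [sq, b, np, hn] for the new tokens only
    # key_layer / value_layer: [sk, b, np, hn] for the new tokens only
    if layer_past is not None:
        past_key, past_value = layer_past
        # Prepend the cached keys/values along the sequence dimension, so the
        # key/value length sk grows while the query length sq can stay at 1
        # during incremental decoding.
        key_layer = torch.cat([past_key, key_layer], dim=0)
        value_layer = torch.cat([past_value, value_layer], dim=0)
    present = (key_layer, value_layer)
    return query_layer, key_layer, value_layer, present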