Fix flash attention for ROCm #7011
base: master
Conversation
I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will do an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.
I don't know how to get it to compile on Windows :(
Sorry, I didn't realize it had been closed on purpose. Is the dependency that bad, though? rocWMMA is header-only, so there is no link-time requirement, and it enables sharing the existing CUDA code. The performance is not better, but the VRAM saving can be very significant, 1 GB in one case. The PR is not ready to merge as is anyway; I need to disable flash-attn in CMake by default for AMD GPUs, or enable it only if rocWMMA is detected as installed. I might not be a ROCm expert, but I am a C++ dev and I own a 7900 XTX; if this is not merged, I might maintain this fork anyway. Of course, if you have already planned to work on that other implementation soon, all of this comment is irrelevant, but having access to a rocWMMA-based version as a comparison could be useful, I don't know. Please let me know what you think.
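The conditional enablement described above could be sketched in CMake roughly as follows. This is only an illustration of the idea; `ROCWMMA_INCLUDE_DIR`, the `ggml` target name, and the `GGML_HIP_ROCWMMA_FATTN` define are placeholders, not necessarily the project's actual build variables:

```cmake
# Sketch: enable the HIP FlashAttention path only when the header-only
# rocWMMA library is found. Variable and define names are illustrative.
find_path(ROCWMMA_INCLUDE_DIR rocwmma/rocwmma.hpp
          HINTS ${ROCM_PATH}/include /opt/rocm/include)

if (ROCWMMA_INCLUDE_DIR)
    message(STATUS "rocWMMA found: ${ROCWMMA_INCLUDE_DIR}")
    target_include_directories(ggml PRIVATE ${ROCWMMA_INCLUDE_DIR})
    target_compile_definitions(ggml PRIVATE GGML_HIP_ROCWMMA_FATTN)
else()
    message(STATUS "rocWMMA not found; FlashAttention disabled for HIP")
endif()
```

Because rocWMMA is header-only, detection only needs to locate the header; no library linkage changes are required.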
I wasn't able to test flash attention on Windows with the 7900 XTX yet.
So I can say that for CDNA this makes a big difference. This PR:
Latest master:
Both of those are still terrible compared to exllama, but this PR does make a big difference in the right direction. Note that I had to make some trivial changes to this PR to make it choose the WMMA path for gfx908.
I'd like to mention it here too: after some optimization work on the GEMM kernels (#8082), this PR now improves pp performance on CDNA by almost 2x, and I really think the stance towards this PR needs to be revised. A tiny, optional, header-only dependency is surely worth a 2x or even a 10% increase in speed, and the fact that the CUDA equivalent dependency is fine but the ROCm equivalent is not speaks volumes, as does the comment on ROCm performance here: #7716.
My original plan was to buy an AMD GPU with tensor cores so that I can test and maintain these changes myself (currently I only have an RX 6800). But I currently have issues finding a slot for it in one of my machines. However, if I can get a pledge from you that you will help with maintenance I would be fine with merging a PR like this. Keep in mind though that the WMMA FlashAttention kernels that I wrote for CUDA are bad in the first place. They rely on the "high-level" WMMA functionality to use tensor cores but after talking to an NVIDIA engineer and doing some related work myself the better way to utilize tensor cores is via PTX instructions (CUDA equivalent of assembly). So I want to at some point rewrite the code accordingly. Instead of rocWMMA it would be much better to implement the equivalent AMD functionality in
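For context, the "high-level" WMMA functionality referred to is CUDA's fragment-based API in `mma.h`. A minimal illustrative tile multiply is shown below; this is not the actual FlashAttention kernel, just the shape of the API the comment is contrasting with hand-written PTX `mma.sync` instructions:

```cuda
#include <mma.h>
using namespace nvcuda;

// Illustrative 16x16x16 tile multiply-accumulate via the high-level WMMA
// API. The lower-level approach discussed above would instead issue
// mma.sync PTX instructions directly, giving finer control over register
// layout and scheduling.
__global__ void wmma_gemm_tile(const half *a, const half *b, float *c) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> fa;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> fb;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> fc;

    wmma::fill_fragment(fc, 0.0f);
    wmma::load_matrix_sync(fa, a, 16);   // leading dimension 16
    wmma::load_matrix_sync(fb, b, 16);
    wmma::mma_sync(fc, fa, fb, fc);
    wmma::store_matrix_sync(c, fc, 16, wmma::mem_row_major);
}
```

rocWMMA mirrors this fragment API for AMD hardware, which is what makes sharing the existing CUDA code possible in the first place.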
I can't accept maintainership of llama.cpp/HIP. I can promise to run regular testing (automated, even, if desired) on CDNA. The current state of affairs also strongly discourages any optimization effort on my and others' part: even if you do some work to optimize the HIP backend, and even if you manage to get that merged, the NVIDIA-centric churn in the common code base inevitably breaks performance again, usually shortly after. Also note that gfx11's WMMA and gfx908/a/4x's MFMA are very different, with totally different hardware implementations and performance characteristics.
When I make changes to the CUDA code I test it for correctness and performance using my RX 6800. My standard for numerical software is that correct results are the first and foremost priority. I very much do not want to have any broken code in my repositories. So if I cannot test or assert that the code produces correct results myself, and if I also cannot delegate this to anyone else, then I am simply not willing to merge the corresponding PR. The simple rocWMMA prototype that I did still required fixes from other people to work at all. My current stance towards HIP performance is that I am willing to invest some effort for AMD support "within reason". When it comes to MMQ in particular, the performance depends very heavily on the exact data layout, and for good AMD performance you would have to completely rewrite the code anyway.
llama-bench
buffer = ROCm0 compute buffer size