From bb7fe989e1941fa223e9627c7072cdbea15355c8 Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Tue, 4 Jun 2024 15:29:49 -0700
Subject: [PATCH] align the inputs with flash attention

Summary:
Each run will go through a range of N_CTX and we will have D_HEAD set via arg.
Also add an additional key for autotuning.

Reviewed By: xuzhao9, sijiac

Differential Revision: D57979174

fbshipit-source-id: e415acdea7f5f912574dc227a0e52ff13bf2b261
---
 torchbenchmark/util/kernels/triton_fused_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchbenchmark/util/kernels/triton_fused_attention.py b/torchbenchmark/util/kernels/triton_fused_attention.py
index a50ef555de..f17c628b44 100644
--- a/torchbenchmark/util/kernels/triton_fused_attention.py
+++ b/torchbenchmark/util/kernels/triton_fused_attention.py
@@ -112,7 +112,7 @@ def _attn_fwd_inner(
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=8),
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8),
     ],
-    key=["N_CTX"],
+    key=["N_CTX", "H"],
 )
 @triton.jit
 def _attn_fwd(
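
Note on the change: Triton caches autotuning results per unique combination of the values of the arguments listed in `key`, so adding "H" means a launch with a different head count triggers a fresh config sweep instead of reusing a block-size choice tuned for another H. The sketch below is not part of the patch; it illustrates the `key` mechanism on a hypothetical toy copy kernel (`copy_kernel`, `copy`, and their arguments are illustrative stand-ins, not the real `_attn_fwd` signature).

# Illustrative sketch only: shows how triton.autotune's `key` argument drives
# re-tuning. The kernel and wrapper names here are hypothetical.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2),
        triton.Config({"BLOCK": 128}, num_warps=4),
    ],
    # The best config is cached per unique (N_CTX, H) value pair; adding "H"
    # (as the patch does for _attn_fwd) forces a new sweep when the head
    # count changes rather than reusing a config picked for a different H.
    key=["N_CTX", "H"],
)
@triton.jit
def copy_kernel(x_ptr, y_ptr, N_CTX, H, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N_CTX * H
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


def copy(x, n_ctx, h):
    # BLOCK is supplied by the autotuner, so it is not passed explicitly.
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    copy_kernel[grid](x, y, n_ctx, h)
    return y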