align the inputs with flash attention
Summary:
Each run will go through a range of N_CTX, with D_HEAD set via an argument (a minimal sketch of this sweep follows the commit message). Also add "H" as an additional key for autotuning.

Reviewed By: xuzhao9, sijiac

Differential Revision: D57979174

fbshipit-source-id: e415acdea7f5f912574dc227a0e52ff13bf2b261
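
For illustration only, a minimal sketch of the sweep described in the summary, assuming a simple standalone driver rather than torchbenchmark's actual harness; the argument name, tensor shapes, and N_CTX values below are hypothetical:

# Hypothetical driver (not torchbenchmark's harness): each run sweeps a range of
# N_CTX while D_HEAD comes from a command-line argument, as the summary describes.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--d-head", type=int, default=64)  # hypothetical argument name
args = parser.parse_args()

BATCH, H = 4, 32  # illustrative batch size and head count
for n_ctx in (1024, 2048, 4096, 8192, 16384):  # illustrative N_CTX range
    q = torch.randn(BATCH, H, n_ctx, args.d_head, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # ... invoke the fused-attention kernel under test with (q, k, v) and time it ...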
manman-ren authored and facebook-github-bot committed Jun 4, 2024
1 parent 1b8725c commit bb7fe98
Showing 1 changed file with 1 addition and 1 deletion: torchbenchmark/util/kernels/triton_fused_attention.py
@@ -112,7 +112,7 @@ def _attn_fwd_inner(
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=8),
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8),
     ],
-    key=["N_CTX"],
+    key=["N_CTX", "H"],
 )
 @triton.jit
 def _attn_fwd(
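
For context, a minimal sketch of how a triton.autotune key list like this behaves; the kernel below is a made-up elementwise example, not the actual _attn_fwd kernel, and the config values are illustrative:

# Illustrative only: the autotuner benchmarks every Config the first time it sees
# a new combination of the argument values named in `key`, then caches the winner
# for that combination. Adding "H" means a change in head count also triggers a
# fresh tuning sweep instead of reusing timings measured for a different H.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4),
        triton.Config({"BLOCK": 128}, num_warps=8),
    ],
    key=["N_CTX", "H"],  # re-tune whenever either of these argument values changes
)
@triton.jit
def _scale_kernel(x_ptr, out_ptr, N_CTX, H, BLOCK: tl.constexpr):
    # Trivial elementwise kernel over an (H * N_CTX)-element buffer.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N_CTX * H
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)


def scale(x: torch.Tensor, n_ctx: int, h: int) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_ctx * h, meta["BLOCK"]),)
    _scale_kernel[grid](x, out, n_ctx, h)
    return out

With only ["N_CTX"] as the key, timings measured at one head count would be reused for every other head count with the same N_CTX; adding "H" avoids that.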
