-
Notifications
You must be signed in to change notification settings - Fork 972
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into deepspeed_main
- Loading branch information
Showing
10 changed files
with
372 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
import flash_attn_cuda | ||
|
||
|
||
def _flash_attn_forward(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
    num_splits=0,
    generator=None,
):
    """Invoke the FlashAttention forward CUDA kernel on varlen (packed) inputs.

    The attention output is written into ``out`` in place; the same tensor is
    also returned for convenience, together with the log-sum-exp of the softmax
    and (optionally) the softmax/dropout-mask tensor.

    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
    Don't change it unless you know what you're doing.
    """
    results = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): positional flag forwarded as-is — presumably zero_tensors; confirm against flash_attn_cuda.fwd
        causal,
        return_softmax,
        num_splits,
        generator,
    )
    softmax_lse = results[0]
    # The kernel appends the softmax/dropout-mask tensor only when it was requested.
    S_dmask = results[1] if return_softmax else None
    return out, softmax_lse, S_dmask
|
||
|
||
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    num_splits=0,
    generator=None,
):
    """Invoke the FlashAttention backward CUDA kernel on varlen (packed) inputs.

    Gradients are written into the pre-allocated ``dq``, ``dk``, ``dv`` tensors
    in place; the same tensors are returned together with the softmax gradient
    computed by the kernel.

    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
    as num_splits=3), so effectively the choices are 0, 1, and 2.
    This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
    """
    kernel_out = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # NOTE(review): positional flag forwarded as-is — presumably zero_tensors; confirm against flash_attn_cuda.bwd
        causal,
        num_splits,
        generator,
    )
    # First three return values duplicate dq/dk/dv (already filled in place);
    # only the softmax gradient is new information.
    softmax_d = kernel_out[3]
    return dq, dk, dv, softmax_d
|
||
|
||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around the FlashAttention CUDA kernels for packed QKV
    input: a single tensor whose second dimension indexes q (0), k (1), v (2).

    The backward pass re-runs dropout, so the CUDA RNG state captured in
    forward is restored before the backward kernel and the caller's RNG
    state is put back afterwards.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        # Save rng_state because the backward pass will regenerate the dropout mask.
        # Skipped when dropout is off — backward is then deterministic anyway.
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scale is qkv.shape[-1] ** -0.5 — presumably 1/sqrt(head_dim),
            # with the head dim last; confirm against the expected qkv layout.
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, 0],  # q
            qkv[:, 1],  # k
            qkv[:, 2],  # v
            torch.empty_like(qkv[:, 0]),  # output buffer, filled in place by the kernel
            cu_seqlens,  # same cumulative lengths for q and k: self-attention
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax,
        )
        # rng_state may be None; save_for_backward tolerates None entries.
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        # With return_softmax, also expose the log-sum-exp and mask for inspection.
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs returned when
        # return_softmax was set; they are not used here.
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            # Rewind the CUDA RNG to the forward-time state so the backward
            # kernel regenerates the identical dropout mask.
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = torch.empty_like(qkv)
        # Gradients are written into dqkv[:, 0..2] in place by the kernel.
        _flash_attn_backward(
            dout,
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            # Restore the caller's RNG state so unrelated随 randomness is unaffected.
            torch.cuda.set_rng_state(cur_rng_state)
        # One gradient per forward input: only qkv is differentiable.
        return dqkv, None, None, None, None, None, None
|
||
|
||
def flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Functional entry point for FlashAttention on packed (unpadded) QKV input.

    Thin convenience wrapper that forwards every argument, in order, to
    ``FlashAttnQKVPackedFunc.apply`` so the custom autograd path is used.
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.