Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deepspeed benchmarking #878

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f4706e0
add flash_attn_kvpacked
satpalsr Mar 29, 2023
f4a9106
Changed is_pipe_parallel setting to fix pipeline-parallel inference
curt-tigges Mar 31, 2023
83a7b9a
Update NeoXArgs docs automatically
invalid-email-address Mar 31, 2023
45d7052
fix formatting
satpalsr Apr 11, 2023
857c556
gpt benchmark script
cr458 Apr 3, 2023
1ab5bf3
remove duplicate argparse
cr458 Apr 4, 2023
afb6b29
HF inference
cr458 Apr 4, 2023
3f7d605
benchmarking configs + script changes
cr458 Apr 11, 2023
d99d2ce
plot directly, runs deepspeed and hf for single benchmark
cr458 Apr 12, 2023
b0e9745
remove plotting comments
cr458 Apr 12, 2023
9c645dd
accept changes from main & resolve conflicts
satpalsr Apr 15, 2023
ee99945
Merge branch 'main' into flash_attn_infer
satpalsr Apr 15, 2023
9b1733e
tmp changes
cr458 Apr 17, 2023
22cac56
Merge remote-tracking branch 'satpalsr/flash_attn_infer' into deepspe…
cr458 Apr 17, 2023
466749b
merge conflict git hash
cr458 Apr 17, 2023
b10739f
separate scripts for Deepspeed/HF and neox
cr458 Apr 18, 2023
4990f9b
debugging: works when world size > 1 but not otherwise
cr458 Apr 18, 2023
88981b2
working ( but not serially)
cr458 Apr 19, 2023
5e3ca7f
working ish gpt-neox just need to figure out how to get dataframe back
cr458 Apr 20, 2023
3ee9d3b
get dataframe output from stdout
cr458 Apr 20, 2023
2a6e8cd
remove gpt neox inference from script
cr458 May 21, 2023
7ea22d9
remove lines
cr458 May 21, 2023
ef4fdd4
device error
cr458 May 21, 2023
d8184f3
Add DS inference
satpalsr May 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add flash_attn_kvpacked
  • Loading branch information
satpalsr committed Mar 29, 2023
commit f4706e004ab53a53ac4aba554ca5fdf530542070
272 changes: 272 additions & 0 deletions megatron/model/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,275 @@ def flash_attn_unpadded_qkvpacked_func(
return FlashAttnQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
)


class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
q,
kv[:, 0],
kv[:, 1],
torch.empty_like(q),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax,
)
ctx.save_for_backward(
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
(
q,
kv,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
rng_state,
) = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(
dout,
q,
kv[:, 0],
kv[:, 1],
out,
softmax_lse,
dq,
dkv[:, 0],
dkv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None


def flash_attn_unpadded_kvpacked_func(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_attn_probs,
)


class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
q,
k,
v,
torch.empty_like(q),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax,
)
ctx.save_for_backward(
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
(
q,
k,
v,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
rng_state,
) = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_attn_probs,
)
Loading