Deepspeed benchmarking #878

Status: Draft
Wants to merge 24 commits into base: main
This view shows changes from 1 commit of the 24.
Commits:
f4706e0  add flash_attn_kvpacked (satpalsr, Mar 29, 2023)
f4a9106  Changed is_pipe_parallel setting to fix pipeline-parallel inference (curt-tigges, Mar 31, 2023)
83a7b9a  Update NeoXArgs docs automatically (invalid-email-address, Mar 31, 2023)
45d7052  fix formatting (satpalsr, Apr 11, 2023)
857c556  gpt benchmark script (cr458, Apr 3, 2023)
1ab5bf3  remove duplicate argparse (cr458, Apr 4, 2023)
afb6b29  HF inference (cr458, Apr 4, 2023)
3f7d605  benchmarking configs + script changes (cr458, Apr 11, 2023)
d99d2ce  plot directly, runs deepspeed and hf for single benchmark (cr458, Apr 12, 2023)
b0e9745  remove plotting comments (cr458, Apr 12, 2023)
9c645dd  accept changes from main & resolve conflicts (satpalsr, Apr 15, 2023)
ee99945  Merge branch 'main' into flash_attn_infer (satpalsr, Apr 15, 2023)
9b1733e  tmp changes (cr458, Apr 17, 2023)
22cac56  Merge remote-tracking branch 'satpalsr/flash_attn_infer' into deepspe… (cr458, Apr 17, 2023)
466749b  merge conflict git hash (cr458, Apr 17, 2023)
b10739f  separate scripts for Deepspeed/HF and neox (cr458, Apr 18, 2023)
4990f9b  debugging: works when world size > 1 but not otherwise (cr458, Apr 18, 2023)
88981b2  working ( but not serially) (cr458, Apr 19, 2023)
5e3ca7f  working ish gpt-neox just need to figure out how to get dataframe back (cr458, Apr 20, 2023)
3ee9d3b  get dataframe output from stdout (cr458, Apr 20, 2023)
2a6e8cd  remove gpt neox inference from script (cr458, May 21, 2023)
7ea22d9  remove lines (cr458, May 21, 2023)
ef4fdd4  device error (cr458, May 21, 2023)
d8184f3  Add DS inference (satpalsr, May 22, 2023)
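One of the commits above (3ee9d3b, "get dataframe output from stdout") runs a benchmark in a subprocess and recovers the results as a DataFrame from its stdout. The sketch below shows one way such a handoff could work; the script name, flags, and sentinel protocol are assumptions for illustration, not taken from this PR.

```python
import io
import subprocess

import pandas as pd

BEGIN, END = "===BEGIN_DF===", "===END_DF==="

# Hypothetical: the child benchmark script is assumed to print its results DataFrame
# as CSV between the two sentinel markers, e.g.
#   print(BEGIN); print(df.to_csv(index=False)); print(END)
proc = subprocess.run(
    ["python", "benchmark_inference.py", "--model", "gpt-neox"],  # assumed script and flags
    capture_output=True,
    text=True,
    check=True,
)

# Parse the CSV text between the sentinels back into a DataFrame on the parent side.
csv_text = proc.stdout.split(BEGIN, 1)[1].split(END, 1)[0]
df = pd.read_csv(io.StringIO(csv_text))
print(df.head())
```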
accept changes from main & resolve conflicts
satpalsr committed Apr 15, 2023
commit 9c645dd8629eb0f83719d902d3380539cdc3b4b1
10 changes: 9 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

-    Default = 142b4b6
+    Default = ce9bee3

current git hash of repository

@@ -1951,6 +1951,14 @@ Args for deepspeed runner (deepspeed.launcher.runner).



- **force_multi**: bool

Default = False

Force multi-node training even if only one node is specified.



- **detect_nvlink_pairs**: bool

Default = False
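For context on the new **force_multi** flag documented in this hunk, here is a minimal sketch of how it might appear in a GPT-NeoX YAML config. Only the `force_multi` key and its meaning come from the docs above; the other keys and values are illustrative assumptions.

```yaml
# Hypothetical runner section of a GPT-NeoX config file (values are placeholders).
{
  # DeepSpeed runner settings
  "hostfile": "/path/to/hostfile",

  # Force the multi-node launch path even if only one node is specified.
  "force_multi": true,
}
```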
29 changes: 18 additions & 11 deletions megatron/model/flash_attention.py
@@ -4,10 +4,17 @@
import torch.nn as nn
import torch.nn.functional as F

+from flash_attn import flash_attn_triton
import flash_attn_cuda


-def _flash_attn_forward(
+def flash_attn_unpadded_unpacked_func_triton(
+    q, k, v, bias=None, causal=False, softmax_scale=None
+):
+    return flash_attn_triton.flash_attn_func(q, k, v, bias, causal, softmax_scale)


+def _flash_attn_forward_cuda(
q,
k,
v,
@@ -51,7 +58,7 @@ def _flash_attn_forward(
return out, softmax_lse, S_dmask


-def _flash_attn_backward(
+def _flash_attn_backward_cuda(
dout,
q,
k,
@@ -120,7 +127,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
-out, softmax_lse, S_dmask = _flash_attn_forward(
+out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@@ -148,7 +155,7 @@ def backward(ctx, dout, *args):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = torch.empty_like(qkv)
-_flash_attn_backward(
+_flash_attn_backward_cuda(
dout,
qkv[:, 0],
qkv[:, 1],
@@ -171,7 +178,7 @@ def backward(ctx, dout, *args):
return dqkv, None, None, None, None, None, None


-def flash_attn_unpadded_qkvpacked_func(
+def flash_attn_unpadded_qkvpacked_func_cuda(
qkv,
cu_seqlens,
max_seqlen,
@@ -204,7 +211,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
-out, softmax_lse, S_dmask = _flash_attn_forward(
+out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
q,
kv[:, 0],
kv[:, 1],
@@ -244,7 +251,7 @@ def backward(ctx, dout, *args):
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
-_flash_attn_backward(
+_flash_attn_backward_cuda(
dout,
q,
kv[:, 0],
@@ -267,7 +274,7 @@ def backward(ctx, dout, *args):
return dq, dkv, None, None, None, None, None, None, None, None


-def flash_attn_unpadded_kvpacked_func(
+def flash_attn_unpadded_kvpacked_func_cuda(
q,
kv,
cu_seqlens_q,
@@ -339,7 +346,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
-out, softmax_lse, S_dmask = _flash_attn_forward(
+out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
q,
k,
v,
@@ -379,7 +386,7 @@ def backward(ctx, dout, *args):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-_flash_attn_backward(
+_flash_attn_backward_cuda(
dout,
q,
k,
@@ -402,7 +409,7 @@ def backward(ctx, dout, *args):
return dq, dk, dv, None, None, None, None, None, None, None, None


-def flash_attn_unpadded_func(
+def flash_attn_unpadded_func_cuda(
q,
k,
v,
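The hunks above rename the CUDA-backed helpers with a `_cuda` suffix and add a thin Triton wrapper, `flash_attn_unpadded_unpacked_func_triton`, whose full signature is visible in the first hunk. Below is a minimal usage sketch, assuming the flash-attn Triton kernel's usual (batch, seqlen, nheads, headdim) layout; the shapes and the bias tensor are illustrative assumptions, not taken from this PR.

```python
import torch

from megatron.model.flash_attention import flash_attn_unpadded_unpacked_func_triton

# Shapes follow the flash-attn Triton convention (batch, seqlen, nheads, headdim);
# fp16 on GPU, since the Triton kernel does not run on CPU.
q = torch.randn(2, 128, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 16, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 16, 64, device="cuda", dtype=torch.float16)

# Hypothetical additive attention bias, e.g. an AliBi matrix broadcast over the batch:
# (batch, nheads, seqlen_q, seqlen_k)
bias = torch.zeros(2, 16, 128, 128, device="cuda", dtype=torch.float16)

out = flash_attn_unpadded_unpacked_func_triton(q, k, v, bias=bias, causal=True)
print(out.shape)  # expected: (2, 128, 16, 64)
```

The packed `*_cuda` variants keep their unpadded, cu_seqlens-based calling convention; this commit only renames them.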
45 changes: 45 additions & 0 deletions megatron/model/positional_embeddings.py
@@ -132,6 +132,51 @@ def get_slopes_power_of_2(n):
]
)

    def bias(self, seq_len_q, seq_len_k, device, dtype):
        # [b, np, sq, sk]
        # seq_len_q = x.shape[-2]
        # seq_len_k = x.shape[-1]

        # Initialize the AliBi matrix to match the first provided key length; grow it exponentially
        # afterwards if longer inputs are provided. This is important for inference, where we will
        # encounter progressively longer samples; it should have no effect at training time.
        if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
            a = self.cached_matrix
        else:
            target_seq_len = (
                seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4
            )
            a = -torch.tril(
                torch.arange(target_seq_len)
                .view(target_seq_len, 1)
                .repeat(1, target_seq_len)
                + torch.arange(0, -target_seq_len, -1)
            )
            a = a.to(device).to(dtype)
            slopes = self.slopes.to(a.device).to(a.dtype)
            a = a * slopes.view(self.slopes.shape[0], 1, 1)
            self.cached_seq_len = target_seq_len
            self.cached_matrix = a

        # If the cached AliBi matrix is larger than the key length, clip it.
        if self.cached_seq_len > seq_len_k:
            a = self.cached_matrix[:, :seq_len_k, :seq_len_k]

        if seq_len_q != seq_len_k:
            # At training time x has dimensionality [b, np, sq, sk] with sq == sk, i.e. the
            # number of query tokens equals the number of key tokens. At inference time with a
            # cache in layer_past, sq != sk: sq contains only one token (the last one in the full
            # sequence), so we select the matching row of the cached matrix. Because the cached
            # matrix may be larger than the current sequence (from a previous inference step),
            # we index by the last key position rather than the last query position.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
            a = a[:, seq_len_k - 1, :].view(
                a.shape[0], 1, a.shape[2]
            )  # seq_len_k - 1 points to the last token index in the current inference batch.

        return a

    def forward(self, x):
        # [b, np, sq, sk]
        seq_len_q = x.shape[-2]
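To illustrate the caching behavior of the new `bias` method, here is a small self-contained sketch. The `ToyAlibi` holder is an assumption for illustration (not the actual AliBi class in megatron/model/positional_embeddings.py); it only mirrors the `slopes` / `cached_seq_len` / `cached_matrix` attributes so the logic above can be run in isolation.

```python
import torch


class ToyAlibi:
    """Toy stand-in with the same slopes / cached_seq_len / cached_matrix attributes."""

    def __init__(self, slopes):
        self.slopes = slopes          # (num_heads,) AliBi slopes
        self.cached_seq_len = None
        self.cached_matrix = None

    def bias(self, seq_len_q, seq_len_k, device="cpu", dtype=torch.float32):
        # Reuse the cached matrix when it already covers the requested key length.
        if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
            a = self.cached_matrix
        else:
            # Otherwise (re)build it, growing 4x past the previous size so the cache
            # is rebuilt only rarely during incremental decoding.
            target = seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4
            a = -torch.tril(
                torch.arange(target).view(target, 1).repeat(1, target)
                + torch.arange(0, -target, -1)
            )
            a = a.to(device).to(dtype) * self.slopes.view(-1, 1, 1).to(device).to(dtype)
            self.cached_seq_len, self.cached_matrix = target, a

        # Clip the cache down to the requested key length.
        if self.cached_seq_len > seq_len_k:
            a = self.cached_matrix[:, :seq_len_k, :seq_len_k]

        # During incremental decoding only the last query token is present.
        if seq_len_q != seq_len_k:
            assert seq_len_q == 1
            a = a[:, seq_len_k - 1, :].view(a.shape[0], 1, a.shape[2])
        return a


alibi = ToyAlibi(slopes=torch.tensor([1.0, 0.5]))
print(alibi.bias(8, 8).shape)  # prefill with 8 tokens -> torch.Size([2, 8, 8])
print(alibi.bias(1, 9).shape)  # one decode step       -> torch.Size([2, 1, 9]); cache grows to 32
```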