
Commit

accept changes from main & resolve conflicts
satpalsr committed Apr 15, 2023
1 parent 45d7052 commit 9c645dd
Showing 10 changed files with 224 additions and 105 deletions.
10 changes: 9 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 142b4b6
Default = ce9bee3

current git hash of repository

@@ -1951,6 +1951,14 @@ Args for deepspeed runner (deepspeed.launcher.runner).



- **force_multi**: bool

Default = False

Force multi-node training even if only one node is specified.



- **detect_nvlink_pairs**: bool

Default = False
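The force_multi option documented above is handed through to the DeepSpeed launcher (deepspeed.launcher.runner). As an illustration only, not part of this commit, a brace-style NeoX YAML config could enable it roughly as follows; the key name and file layout here are assumptions based on the existing config convention:

{
  # hypothetical snippet: force the multi-node launch path even when only one node is listed
  "force_multi": true
}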
29 changes: 18 additions & 11 deletions megatron/model/flash_attention.py
@@ -4,10 +4,17 @@
import torch.nn as nn
import torch.nn.functional as F

from flash_attn import flash_attn_triton
import flash_attn_cuda


def _flash_attn_forward(
def flash_attn_unpadded_unpacked_func_triton(
q, k, v, bias=None, causal=False, softmax_scale=None
):
return flash_attn_triton.flash_attn_func(q, k, v, bias, causal, softmax_scale)


def _flash_attn_forward_cuda(
q,
k,
v,
@@ -51,7 +58,7 @@ def _flash_attn_forward(
return out, softmax_lse, S_dmask


def _flash_attn_backward(
def _flash_attn_backward_cuda(
dout,
q,
k,
@@ -120,7 +127,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@@ -148,7 +155,7 @@ def backward(ctx, dout, *args):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
_flash_attn_backward_cuda(
dout,
qkv[:, 0],
qkv[:, 1],
@@ -171,7 +178,7 @@ def backward(ctx, dout, *args):
return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
def flash_attn_unpadded_qkvpacked_func_cuda(
qkv,
cu_seqlens,
max_seqlen,
@@ -204,7 +211,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
q,
kv[:, 0],
kv[:, 1],
@@ -244,7 +251,7 @@ def backward(ctx, dout, *args):
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(
_flash_attn_backward_cuda(
dout,
q,
kv[:, 0],
@@ -267,7 +274,7 @@ def backward(ctx, dout, *args):
return dq, dkv, None, None, None, None, None, None, None, None


def flash_attn_unpadded_kvpacked_func(
def flash_attn_unpadded_kvpacked_func_cuda(
q,
kv,
cu_seqlens_q,
@@ -339,7 +346,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
q,
k,
v,
@@ -379,7 +386,7 @@ def backward(ctx, dout, *args):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
_flash_attn_backward_cuda(
dout,
q,
k,
@@ -402,7 +409,7 @@ def backward(ctx, dout, *args):
return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attn_unpadded_func(
def flash_attn_unpadded_func_cuda(
q,
k,
v,
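The renames above separate the CUDA-backed helpers (the *_cuda functions wrapping flash_attn_cuda) from the new Triton wrapper around flash_attn_triton.flash_attn_func. As a rough usage sketch only, not part of the commit, the Triton wrapper might be exercised as below; the (batch, seqlen, num_heads, head_dim) layout and the fp16/GPU requirements are assumptions about the flash-attn Triton kernel rather than facts stated in this diff:

import torch
from megatron.model.flash_attention import flash_attn_unpadded_unpacked_func_triton

# Illustrative shapes; real callers pass the projected query/key/value tensors.
batch, seqlen, num_heads, head_dim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal self-attention with no additive bias; softmax_scale=None lets the kernel
# fall back to its default scaling (typically 1/sqrt(head_dim)).
out = flash_attn_unpadded_unpacked_func_triton(q, k, v, bias=None, causal=True)
print(out.shape)  # expected to match q: (2, 1024, 16, 64)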
45 changes: 45 additions & 0 deletions megatron/model/positional_embeddings.py
@@ -132,6 +132,51 @@ def get_slopes_power_of_2(n):
]
)

def bias(self, seq_len_q, seq_len_k, device, dtype):
# [b, np, sq, sk]
# seq_len_q = x.shape[-2]
# seq_len_k = x.shape[-1]

# Initialize the AliBi matrix to match the first provided key length; grow it exponentially
# afterwards if longer inputs are provided. This is important for inference, where we will
# encounter progressively longer samples; it should have no effect at training time.
if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
a = self.cached_matrix
else:
target_seq_len = (
seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4
)
a = -torch.tril(
torch.arange(target_seq_len)
.view(target_seq_len, 1)
.repeat(1, target_seq_len)
+ torch.arange(0, -target_seq_len, -1)
)
a = a.to(device).to(dtype)
slopes = self.slopes.to(a.device).to(a.dtype)
a = a * slopes.view(self.slopes.shape[0], 1, 1)
self.cached_seq_len = target_seq_len
self.cached_matrix = a

# If the AliBi matrix is larger than the key length, clip it.
if self.cached_seq_len > seq_len_k:
a = self.cached_matrix[:, :seq_len_k, :seq_len_k]

if seq_len_q != seq_len_k:
            # In the training case x has dimensionality [b, np, sq, sk] with sq == sk:
            # the number of query tokens equals the number of key tokens.
            # At inference time with a cache in layer_past, sq != sk; sq holds only the last
            # token of the full sequence, so we take the matching row of the cached matrix.
            # Because the cache may already be larger than the current sequence (from an
            # earlier pass), we index with seq_len_k - 1 rather than the last cached row.
assert (
seq_len_q == 1
), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
a = a[:, seq_len_k - 1, :].view(
a.shape[0], 1, a.shape[2]
) # seq_len_k - 1 points to the last token index in the current inference batch.

return a

def forward(self, x):
# [b, np, sq, sk]
seq_len_q = x.shape[-2]
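To make the new bias method easier to follow, here is a small standalone sketch, not part of the commit, that reproduces the matrix it builds for a toy sequence length and made-up slope values (the real slopes come from the class's get_slopes logic):

import torch

seq_len = 4
num_heads = 2
slopes = torch.tensor([0.5, 0.25])  # illustrative values only

# Relative distance of each key from each query, lower-triangular and negated so that
# keys further in the past receive a larger penalty.
rel = -torch.tril(
    torch.arange(seq_len).view(seq_len, 1).repeat(1, seq_len)
    + torch.arange(0, -seq_len, -1)
)
# rel == [[ 0,  0,  0,  0],
#         [-1,  0,  0,  0],
#         [-2, -1,  0,  0],
#         [-3, -2, -1,  0]]

# Scale per head, mirroring how the cached matrix is built in the method above.
bias = rel * slopes.view(num_heads, 1, 1)  # shape [num_heads, seq_len, seq_len]
print(bias.shape)  # torch.Size([2, 4, 4])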
