ALibi & Flash Attention (#864)
* Code formatting

Signed-off-by: Dashiell Stander <[email protected]>

* Import it differently

Signed-off-by: Dashiell Stander <[email protected]>

* Bump up triton dependency

Signed-off-by: Dashiell Stander <[email protected]>

* No kwargs

Signed-off-by: Dashiell Stander <[email protected]>

* No kwargs

Signed-off-by: Dashiell Stander <[email protected]>

* Get the signature right

Signed-off-by: Dashiell Stander <[email protected]>

* Get the signature right

Signed-off-by: Dashiell Stander <[email protected]>

* Add dim for num heads to bias

Signed-off-by: Dashiell Stander <[email protected]>

* Add dim for num heads to bias

Signed-off-by: Dashiell Stander <[email protected]>

* Add dim for num heads to bias

Signed-off-by: Dashiell Stander <[email protected]>

* Think I have the shape right now?

Signed-off-by: Dashiell Stander <[email protected]>

* blegh shapes

Signed-off-by: Dashiell Stander <[email protected]>

* blegh shapes

Signed-off-by: Dashiell Stander <[email protected]>

* Need to get the triton version just right

Signed-off-by: Dashiell Stander <[email protected]>

* Remove debug print statements

Signed-off-by: Dashiell Stander <[email protected]>

* Need to permute the dimensions before returning

Signed-off-by: Dashiell Stander <[email protected]>

* Update SparseAttention signature

Signed-off-by: Dashiell Stander <[email protected]>

* Clean up code.

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Signed-off-by: Dashiell Stander <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
3 people committed Apr 11, 2023
1 parent 038b011 commit f3d65b5
Showing 5 changed files with 120 additions and 55 deletions.
configs/neox_arguments.md: 2 changes (1 addition, 1 deletion)
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = c6c1be7
Default = cb6042d

current git hash of repository

megatron/model/flash_attention.py: 17 changes (12 additions, 5 deletions)
@@ -4,10 +4,11 @@
import torch.nn as nn
import torch.nn.functional as F

from flash_attn import flash_attn_triton
import flash_attn_cuda


def _flash_attn_forward(
def _flash_attn_forward_cuda(
q,
k,
v,
@@ -51,7 +52,7 @@ def _flash_attn_forward(
return out, softmax_lse, S_dmask


def _flash_attn_backward(
def _flash_attn_backward_cuda(
dout,
q,
k,
@@ -120,7 +121,7 @@ def forward(
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
out, softmax_lse, S_dmask = _flash_attn_forward_cuda(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@@ -148,7 +149,7 @@ def backward(ctx, dout, *args):
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
_flash_attn_backward_cuda(
dout,
qkv[:, 0],
qkv[:, 1],
@@ -171,7 +172,7 @@ def backward(ctx, dout, *args):
return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
def flash_attn_unpadded_qkvpacked_func_cuda(
qkv,
cu_seqlens,
max_seqlen,
@@ -183,3 +184,9 @@ def flash_attn_unpadded_qkvpacked_func(
return FlashAttnQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
)


def flash_attn_unpadded_qkvpacked_func_triton(
q, k, v, bias=None, causal=False, softmax_scale=None
):
return flash_attn_triton.flash_attn_func(q, k, v, bias, causal, softmax_scale)
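
For orientation, below is a hedged sketch of how the new Triton-backed wrapper might be invoked. The tensor layout ([batch, seqlen, num_heads, head_dim] inputs, with an additive bias broadcastable to [batch, num_heads, seqlen_q, seqlen_k]) is assumed from the upstream flash-attn Triton kernel and is not spelled out in this diff; despite the shared naming with the CUDA wrapper, this path takes unpacked q/k/v rather than a packed qkv tensor with cu_seqlens.

```python
# Hypothetical usage sketch of the Triton wrapper added above (not part of the diff).
# Assumes a CUDA device and that the flash-attn Triton kernel is importable.
import torch

from megatron.model.flash_attention import flash_attn_unpadded_qkvpacked_func_triton

batch, seqlen, heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Optional additive bias, e.g. an ALiBi matrix tiled over the batch dimension.
bias = torch.zeros(batch, heads, seqlen, seqlen, device="cuda", dtype=torch.float16)

out = flash_attn_unpadded_qkvpacked_func_triton(q, k, v, bias=bias, causal=True)
print(out.shape)  # expected: torch.Size([2, 128, 8, 64])
```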
megatron/model/positional_embeddings.py: 45 changes (45 additions, 0 deletions)
@@ -132,6 +132,51 @@ def get_slopes_power_of_2(n):
]
)

def bias(self, seq_len_q, seq_len_k, device, dtype):
# [b, np, sq, sk]
# seq_len_q = x.shape[-2]
# seq_len_k = x.shape[-1]

# Initialize the AliBi matrix to match the first provided key length; grow it exponentially
# afterwards if longer inputs are provided. This is important for inference, where we will
# encounter progressively longer samples; it should have no effect at training time.
if self.cached_seq_len is not None and self.cached_seq_len >= seq_len_k:
a = self.cached_matrix
else:
target_seq_len = (
seq_len_k if self.cached_seq_len is None else self.cached_seq_len * 4
)
a = -torch.tril(
torch.arange(target_seq_len)
.view(target_seq_len, 1)
.repeat(1, target_seq_len)
+ torch.arange(0, -target_seq_len, -1)
)
a = a.to(device).to(dtype)
slopes = self.slopes.to(a.device).to(a.dtype)
a = a * slopes.view(self.slopes.shape[0], 1, 1)
self.cached_seq_len = target_seq_len
self.cached_matrix = a

# If the AliBi matrix is larger than the key length, clip it.
if self.cached_seq_len > seq_len_k:
a = self.cached_matrix[:, :seq_len_k, :seq_len_k]

if seq_len_q != seq_len_k:
# In the train case x has dimensionality [b, np, sq, sk] with sq == sk
# The number of query tokens is equal to the number of key tokens
# At inference time with a cache in layer_past, sq is not equal to sk: sq contains only one token (the last one in the full sequence).
# In this case we index into the cached matrix at the appropriate token position.
# Because the cached matrix may already be larger from a previous inference step, we index by the key length (seq_len_k - 1) rather than by the last position in the sq sequence.
assert (
seq_len_q == 1
), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
a = a[:, seq_len_k - 1, :].view(
a.shape[0], 1, a.shape[2]
) # seq_len_k - 1 points to the last token index in the current inference batch.

return a

def forward(self, x):
# [b, np, sq, sk]
seq_len_q = x.shape[-2]
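To make the shape of the bias returned above concrete, here is a self-contained sketch of the same construction without the caching, growth, and clipping logic. The helper name and the example slope values are illustrative only; `slopes` stands in for `self.slopes` with shape [num_heads].

```python
# Standalone sketch of the ALiBi bias built by bias() above (no caching or clipping).
import torch

def alibi_bias_sketch(slopes: torch.Tensor, seq_len: int, device, dtype) -> torch.Tensor:
    # distances[i, j] = -(i - j) for j <= i and 0 above the diagonal: each query position
    # is penalised in proportion to how far back the key position lies.
    distances = -torch.tril(
        torch.arange(seq_len).view(seq_len, 1).repeat(1, seq_len)
        + torch.arange(0, -seq_len, -1)
    )
    distances = distances.to(device=device, dtype=dtype)
    # Scale each head's plane by its slope -> [num_heads, seq_len, seq_len].
    return distances * slopes.view(-1, 1, 1).to(device=device, dtype=dtype)

# Example: 8 heads with the power-of-two slope schedule 1/2, 1/4, ..., 1/256.
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(8)])
bias = alibi_bias_sketch(slopes, seq_len=16, device="cpu", dtype=torch.float32)
print(bias.shape)  # torch.Size([8, 16, 16])
```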
megatron/model/transformer.py: 109 changes (61 additions, 48 deletions)
@@ -267,7 +267,6 @@ def __init__(
self.attention_type = neox_args.attention_config[layer_number]
self.use_flash_attention = self.attention_type == "flash"
self.sparse = self.attention_type not in ("global", "flash")
self.sparse = self.attention_type != "global" and not self.use_flash_attention
if self.sparse:
self.sparse_attn = configure_sparse_attention(
neox_args,
@@ -278,18 +277,18 @@
else:
if self.use_flash_attention:
from megatron.model.flash_attention import (
flash_attn_unpadded_qkvpacked_func,
flash_attn_unpadded_qkvpacked_func_cuda,
flash_attn_unpadded_qkvpacked_func_triton,
)

self.flash_attention_function = flash_attn_unpadded_qkvpacked_func
if self.pos_emb == "alibi":
raise ValueError(
"Flash attention is currently not compatible with AliBi positional embeddings. Use sinuisoidal, learned, or rotary embeddings instead."
self.flash_attention_function = (
flash_attn_unpadded_qkvpacked_func_triton
)
else:
self.flash_attention_function = (
flash_attn_unpadded_qkvpacked_func_cuda
)
from megatron.model.flash_attention import (
flash_attn_unpadded_qkvpacked_func,
)

else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
@@ -428,47 +427,61 @@ def flash_attention(self, query_layer, key_layer, value_layer):
query_layer.size(0),
key_layer.size(0),
)
# [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]

query_layer = query_layer.transpose(0, 1).reshape(
output_size[0] * output_size[2], 1, output_size[1], -1
)
key_layer = key_layer.transpose(0, 1).reshape(
output_size[0] * output_size[3], 1, output_size[1], -1
)
value_layer = value_layer.transpose(0, 1).reshape(
output_size[0] * output_size[3], 1, output_size[1], -1
)

# Combined q/k/v into [b * s, 3, np, hn].
qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)

batch_size = output_size[0]
seqlen = output_size[2]
max_s = seqlen

cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=qkv.device,
)
output = self.flash_attention_function(
qkv,
cu_seqlens,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=None,
causal=True,
)
# [b * sq, np, hn] -> [b, sq, np, hn]
matmul_result = output.view(
output_size[0], output_size[2], output.shape[1], output.shape[2]
)
# [b, sq, np, hn] -> [b, np, sq, hn]
matmul_result = matmul_result.transpose(1, 2)
if self.pos_emb != "alibi":
# [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
query_layer = query_layer.transpose(0, 1).reshape(
output_size[0] * output_size[2], 1, output_size[1], -1
)
key_layer = key_layer.transpose(0, 1).reshape(
output_size[0] * output_size[3], 1, output_size[1], -1
)
value_layer = value_layer.transpose(0, 1).reshape(
output_size[0] * output_size[3], 1, output_size[1], -1
)
# Combined q/k/v into [b * s, 3, np, hn].
qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)

batch_size = output_size[0]
seqlen = output_size[2]
max_s = seqlen

cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=qkv.device,
)

output = self.flash_attention_function(
qkv,
cu_seqlens,
max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=None,
causal=True,
)
# [b * sq, np, hn] -> [b, sq, np, hn]
matmul_result = output.view(
output_size[0], output_size[2], output.shape[1], output.shape[2]
)
# [b, sq, np, hn] -> [b, np, sq, hn]
matmul_result = matmul_result.transpose(1, 2)
else:
# [sq, b, np, hn] -> [b, sq, np, hn]
sq = query_layer.size(0)
b = query_layer.size(1)
sk = key_layer.size(0)
query_layer = query_layer.transpose(0, 1)
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
bias = self.alibi_embed.bias(sq, sk, query_layer.device, query_layer.dtype)
bias = bias.unsqueeze(0).tile((b, 1, 1, 1))
matmul_result = self.flash_attention_function(
query_layer, key_layer, value_layer, bias=bias, causal=True
)
matmul_result = matmul_result.transpose(1, 2)
return matmul_result

def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
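The ALiBi branch added above is mostly shape plumbing before the hand-off to the Triton kernel. The sketch below walks through the same reshaping with a plain softmax attention standing in for the flash kernel so it runs on CPU; the function name is illustrative, and the training-time case sq == sk is assumed.

```python
# Shape walk-through of the ALiBi flash-attention branch, with plain softmax attention
# standing in for the Triton kernel. Assumes sq == sk (training); names are illustrative.
import torch

def alibi_flash_branch_sketch(query, key, value, alibi_bias):
    # query/key/value: [sq, b, np, hn]; alibi_bias: [np, sq, sk] (as returned by bias()).
    sq, b, np_, hn = query.shape
    sk = key.shape[0]
    # [sq, b, np, hn] -> [b, sq, np, hn], the layout handed to the flash kernel.
    q, k, v = (t.transpose(0, 1) for t in (query, key, value))
    # [np, sq, sk] -> [b, np, sq, sk], matching bias.unsqueeze(0).tile((b, 1, 1, 1)) above.
    bias = alibi_bias.unsqueeze(0).tile((b, 1, 1, 1))
    scores = torch.einsum("bqnh,bknh->bnqk", q, k) / hn ** 0.5 + bias
    causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool, device=query.device), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    out = torch.einsum("bnqk,bknh->bqnh", scores.softmax(dim=-1), v)  # [b, sq, np, hn]
    return out.transpose(1, 2)                                        # [b, np, sq, hn]

sq = b = 4
np_, hn = 8, 64
q = torch.randn(sq, b, np_, hn)
k = torch.randn_like(q)
v = torch.randn_like(q)
bias = torch.zeros(np_, sq, sq)
print(alibi_flash_branch_sketch(q, k, v, bias).shape)  # torch.Size([4, 8, 4, 64])
```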
requirements/requirements-sparseattention.txt: 2 changes (1 addition, 1 deletion)
@@ -1 +1 @@
triton==0.4.2
triton==2.0.0.dev20221202
