Switch to using Cuda Flash Attn for Alibi (#1183)
* add CUDA support for flash-attn with ALiBi; warn of deprecation of the Triton backend

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
haileyschoelkopf and github-actions committed Mar 13, 2024
1 parent 6809bbc commit 03186de
Showing 3 changed files with 51 additions and 28 deletions.
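
For context: flash-attn added an `alibi_slopes` argument to its CUDA kernels in v2.4.0.post1, which is what lets this commit route ALiBi through the CUDA path instead of the Triton kernel. A minimal sketch of that call outside of NeoX (shapes, sizes, and the causal flag are illustrative assumptions, not taken from this diff):

```python
import torch
from flash_attn import flash_attn_func  # alibi_slopes requires flash-attn >= 2.4.0.post1

batch, seqlen, nheads, headdim = 2, 128, 8, 64  # illustrative sizes

# q, k, v: [batch, seqlen, nheads, headdim], half precision, on GPU
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

# one ALiBi slope per head, in fp32, as the kernel expects
alibi_slopes = torch.tensor(
    [2 ** (-8 * (i + 1) / nheads) for i in range(nheads)],
    device="cuda",
    dtype=torch.float32,
)

# causal self-attention with the ALiBi bias applied inside the CUDA kernel
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)  # [batch, seqlen, nheads, headdim]
```

The `cu_seqlens_q` / `cu_seqlens_k` tensors built under `if not self.training:` in the diff below are the cumulative-sequence-length inputs that flash-attn's varlen entry points expect for the inference path.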
configs/neox_arguments.md: 2 changes (1 addition, 1 deletion)
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 33f2842
Default = fdac107

current git hash of repository

megatron/model/transformer.py: 63 changes (40 additions, 23 deletions)
@@ -21,6 +21,8 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from pkg_resources import packaging
from importlib.metadata import version

from .norms import get_norm
from megatron import mpu
@@ -412,6 +414,14 @@ def __init__(
self.rope_fusion = neox_args.rope_fusion
self.attention_type = neox_args.attention_config[layer_number]
self.use_flash_attention = self.attention_type == "flash"
self.use_triton = (
self.use_flash_attention
and self.pos_emb == "alibi"
and (
not packaging.version.Version(version("flash-attn"))
>= packaging.version.Version("2.4.0.post1")
)
)
self.sparse = self.attention_type not in ("global", "flash")

if self.gqa:
@@ -578,7 +588,7 @@ def flash_attention(self, query_layer, key_layer, value_layer):
key_layer.size(0),
)

if self.pos_emb != "alibi":
if self.use_flash_attention and not self.use_triton:

# [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
key_layer = key_layer.transpose(0, 1).reshape(
@@ -588,41 +598,46 @@ def flash_attention(self, query_layer, key_layer, value_layer):
output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
)

batch_size = output_size[0]
max_seqlen_q = output_size[2]
max_seqlen_k = output_size[3]

cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device,
)

cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * max_seqlen_k,
step=max_seqlen_k,
dtype=torch.int32,
device=key_layer.device,
)

# [sq, b, np, hn] -> [b, sq, np, hn]
query_layer = query_layer.transpose(0, 1).reshape(
output_size[0], output_size[2], output_size[1], -1
)

# only pass in window_size kwarg to flash-attn
# if we use Sliding Window Attention.
# only pass in window_size or alibi_slopes kwarg
# if we use Sliding Window Attention / AliBi.
# Flash attn defaults to (-1,-1), or
# does not have this kwarg prior to v2.3.0
extra_kwargs = (
{"window_size": (self.sliding_window_width, -1)}
if self.sliding_window_width is not None
else {}
)
if self.pos_emb == "alibi":
extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
query_layer.device
).to(torch.float32)

if not self.training:
batch_size = output_size[0]
max_seqlen_q = output_size[2]
max_seqlen_k = output_size[3]

cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device,
)

cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * max_seqlen_k,
step=max_seqlen_k,
dtype=torch.int32,
device=key_layer.device,
)

q_shape = query_layer.shape
k_shape = key_layer.shape
v_shape = value_layer.shape
@@ -662,6 +677,8 @@ def flash_attention(self, query_layer, key_layer, value_layer):
matmul_result = matmul_result.transpose(1, 2)

else:
# we still use Triton if using AliBi with flash-attn<2.4.0.post1.

# [sq, b, np, hn] -> [b, sq, np, hn]
sq = query_layer.size(0)
b = query_layer.size(1)
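
The `alibi_slopes` passed to flash-attn above come from `self.alibi_embed.slopes`. For reference, the ALiBi paper defines one slope per head as a geometric sequence; a rough sketch of that computation for a power-of-two head count (the helper name is made up here, and NeoX's own implementation also handles non-power-of-two head counts):

```python
import torch

def alibi_slopes_power_of_two(num_heads: int) -> torch.Tensor:
    # ALiBi paper: slope_i = 2 ** (-8 * i / num_heads) for heads i = 1..num_heads.
    # Valid as written only for power-of-two head counts; NeoX interpolates
    # extra slopes when num_heads is not a power of two.
    start = 2 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)], dtype=torch.float32)

# e.g. 8 heads -> tensor([0.5000, 0.2500, 0.1250, ..., 0.0039])
slopes = alibi_slopes_power_of_two(8)
```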
megatron/neox_arguments/arguments.py: 14 changes (10 additions, 4 deletions)
@@ -1095,11 +1095,17 @@ def calculate_derived(self):
self.num_kv_heads % self.model_parallel_size == 0
), "Number of KV heads must be at least model_parallel_size for now!"
# Flash attention version >=2.3.0 required to combine Flash + Sliding Window Attention
if self.sliding_window_width is not None and "flash" in self.attention_config:
if "flash" in self.attention_config:
_flash_version = packaging.version.Version(version("flash-attn"))
assert _flash_version >= packaging.version.Version(
"2.3.0"
), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
if self.sliding_window_width is not None:
assert _flash_version >= packaging.version.Version(
"2.3.0"
), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
if self.pos_emb == "alibi":
if not _flash_version >= packaging.version.Version("2.4.0.post1"):
print(
f"Warning: Flash-Attention version ({str(_flash_version)}) must be >= 2.4.0.post1 to support AliBi. Falling back to flash-attn triton backend, but version 2.4.0.post1 or later will be required in future."
)

# Adding equal dataset weights if none are provided
if self.train_data_paths and (self.train_data_weights is None):
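
Taken together, the changes above amount to a version-gated backend choice: CUDA flash-attn kernels for ALiBi when flash-attn >= 2.4.0.post1 is installed, otherwise the (now deprecated) Triton kernel plus a warning. A standalone sketch of that check (the function name is hypothetical; NeoX performs the equivalent checks inside `calculate_derived` and the attention module's constructor):

```python
import warnings
from importlib.metadata import version
from packaging.version import Version

def pick_alibi_flash_backend() -> str:
    """Hypothetical helper mirroring the version gate introduced in this commit."""
    flash_version = Version(version("flash-attn"))
    if flash_version >= Version("2.4.0.post1"):
        # the CUDA kernels accept alibi_slopes from this release onward
        return "cuda"
    warnings.warn(
        f"flash-attn {flash_version} predates 2.4.0.post1; falling back to the "
        "deprecated Triton ALiBi kernel."
    )
    return "triton"
```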
