
Commit

add flash attention
honglu2875 committed Nov 11, 2023
1 parent 95f492d commit 1425287
Showing 2 changed files with 54 additions and 8 deletions.
26 changes: 19 additions & 7 deletions aria/model/model.py
@@ -5,6 +5,7 @@

from torch import nn as nn
from torch.nn import functional as F
from aria.utils import is_flash_attn_2_available
from aria.model.dynamic_yarn import DynamicYaRNScaledRotaryEmbedding


@@ -168,6 +169,16 @@ def __init__(self, model_config: ModelConfig):
out_features=3 * model_config.d_model,
bias=False,
)

if is_flash_attn_2_available():
from flash_attn import flash_attn_func

self.attn_fn = flash_attn_func
self._use_flash_attn = True
else:
self.attn_fn = F.scaled_dot_product_attention
self._use_flash_attn = False

self.att_proj_linear = nn.Linear(
in_features=model_config.d_model,
out_features=model_config.d_model,
@@ -223,10 +234,11 @@ def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None):
xk = torch.concat([past_kv[0], xk], axis=1)
xv = torch.concat([past_kv[1], xv], axis=1)
kv = (xk, xv)
# Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
if not self._use_flash_attn:
# Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

# Required as we are not using a nn.Dropout layer
if self.training:
@@ -236,8 +248,8 @@ def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None):

# Using beta torch functionality (subject to change)
# See - https://shorturl.at/jtI17
if past_kv is None:
att = F.scaled_dot_product_attention(
if past_kv is None or self._use_flash_attn:
att = self.attn_fn(
query=xq,
key=xk,
value=xv,
@@ -247,7 +259,7 @@ def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None):
else:
assert xq.size(2) == 1
mask = torch.ones(1, xk.size(2), dtype=bool, device=xk.device)
att = F.scaled_dot_product_attention(
att = self.attn_fn(
query=xq,
key=xk,
value=xv,
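
The reshape guard added above reflects the different input layouts of the two kernels: flash_attn_func consumes tensors shaped (b_sz, s_len, n_head, d_head), while F.scaled_dot_product_attention expects (b_sz, n_head, s_len, d_head), so the transpose is only needed on the fallback path. A minimal standalone sketch of that fallback call, with illustrative shapes not taken from the repository:

    import torch
    from torch.nn import functional as F

    b_sz, s_len, n_head, d_head = 2, 16, 4, 32

    # Stand-ins for the projected query/key/value tensors.
    xq = torch.randn(b_sz, s_len, n_head, d_head)
    xk = torch.randn(b_sz, s_len, n_head, d_head)
    xv = torch.randn(b_sz, s_len, n_head, d_head)

    # Fallback path: move heads ahead of the sequence dim before calling SDPA.
    att = F.scaled_dot_product_attention(
        query=xq.transpose(1, 2),
        key=xk.transpose(1, 2),
        value=xv.transpose(1, 2),
        is_causal=True,
    )
    print(att.shape)  # torch.Size([2, 4, 16, 32])
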
36 changes: 35 additions & 1 deletion aria/utils.py
@@ -1,7 +1,10 @@
"""Contains miscellaneous utilities"""

import importlib
import os
import subprocess
from typing import Union, Tuple
from packaging import version

import requests

from pydub import AudioSegment
@@ -55,3 +58,34 @@ def midi_to_audio(mid_path: str, soundfont_path: str | None = None):
print(e)

print(f"Saved files: \n{wav_path}\n{mp3_path}")


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# --- Taken from transformers library ---
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
package_exists = False
if return_version:
return package_exists, package_version
else:
return package_exists


def is_flash_attn_2_available():
# --- Taken from transformers library ---
_flash_attn_2_available = _is_package_available("flash_attn") and version.parse(
importlib.metadata.version("flash_attn")
) >= version.parse("2.1.0")
_torch_available = _is_package_available("torch")
if not _torch_available:
return False

import torch

return _flash_attn_2_available and torch.cuda.is_available()
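
A hypothetical usage of the helpers above (the import path follows this repository's layout; the messages are illustrative only):

    from aria.utils import _is_package_available, is_flash_attn_2_available

    # Report the detected flash-attn version and which attention kernel would be picked.
    flash_ok, flash_version = _is_package_available("flash_attn", return_version=True)
    if is_flash_attn_2_available():
        print(f"flash-attn {flash_version} detected; using flash_attn_func")
    else:
        print(f"flash-attn found={flash_ok}; falling back to F.scaled_dot_product_attention")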
