
Commit

it turns out flash attention in pytorch 2.0 is not handling causal correctly when key and query lengths differ. work around it by manually constructing the mask
lucidrains committed Mar 24, 2023
1 parent 81715e7 commit 49ccd71
Showing 4 changed files with 209 additions and 44 deletions.
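
The gist of the fix: build the causal mask by hand with an offset `triu(j - i + 1)` and pass it to `F.scaled_dot_product_attention` as `attn_mask`, rather than relying on PyTorch's built-in causal handling. Below is a minimal sketch of that idea; the tensor shapes and variable names are illustrative and not taken from the diff.

```python
import torch
import torch.nn.functional as F

# queries shorter than the keys (the differing-lengths case this commit works around)
q = torch.randn(1, 8, 4, 64)     # (batch, heads, q_len, dim_head)
k = torch.randn(1, 8, 16, 64)    # (batch, heads, k_len, dim_head)
v = torch.randn(1, 8, 16, 64)

i, j = q.shape[-2], k.shape[-2]

# True marks positions to hide; the (j - i + 1) offset aligns the last query
# with the last key, the same construction used in the diff below
causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

# scaled_dot_product_attention treats True in attn_mask as "allowed to attend"
out = F.scaled_dot_product_attention(q, k, v, attn_mask = ~causal_mask)
```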
22 changes: 20 additions & 2 deletions README.md
@@ -33,7 +33,8 @@ model = BlockRecurrentTransformer(
xl_memories_layers = (5, 6), # which layers to use xl memories. very old deepmind papers have shown you only need the last penultimate layers to have cached key values to see majority of benefit
num_state_vectors = 512, # number of state vectors, i believe this was a single block size in the paper, but can be any amount
recurrent_layers = (4,), # where to place the recurrent layer(s) for states with fixed simple gating
enhanced_recurrence = True # enhanced recurrence from ernie-doc paper, i have seen it to work well on my local machine
enhanced_recurrence = True, # enhanced recurrence from ernie-doc paper, i have seen it to work well on my local machine
use_flash_attn = True # use flash attention, if on pytorch 2.0
)

seq = torch.randint(0, 2000, (1, 1024))
@@ -62,9 +63,9 @@ $ python train.py
- [x] test full system on enwik8 locally and ablate states and memories and see effects first hand
- [x] make sure attention allow for single head key / values too
- [x] run a few experiments of fixed gating in regular transformers - does not work
- [x] integrate <a href="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/hazyresearch/flash-attention">flash attention</a>

- [ ] revisit <a href="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/lucidrains/memformer">memformer</a>
- [ ] integrate <a href="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/hazyresearch/flash-attention">flash attention</a>
- [ ] add ability to gate in memorizing transformers knn attention layers

## Citations
@@ -112,5 +113,22 @@ $ python train.py
}
```

```bibtex
@inproceedings{Sun2022ALT,
title = {A Length-Extrapolatable Transformer},
author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
year = {2022}
}
```

```bibtex
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```


*Memory is Attention through Time* - Alex Graves
@@ -1,12 +1,16 @@
import math
from random import random
from functools import wraps, partial
from collections import namedtuple
from packaging import version

from einops import rearrange

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops import rearrange, repeat, pack, unpack

from beartype import beartype
from beartype.door import is_bearable
@@ -29,12 +33,42 @@ def inner(self, *args, **kwargs):
return out
return inner

def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner

def compact(arr):
return [*filter(exists, arr)]

def and_reduce(arr: List[torch.Tensor]):
if len(arr) == 0:
return None
head, *rest = arr
for t in rest:
head = head & t
return head

print_once = once(print)

def divisible_by(numer, denom):
return (numer % denom) == 0

def l2norm(t):
return F.normalize(t, dim = -1)

def pack_one(t, pattern):
return pack([t], pattern)

def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

def pad_at_dim(t, pad, dim = -1, value = 0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
@@ -130,6 +164,143 @@ def apply_rotary_pos_emb(t, pos, scale = 1.):

return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

# maybe flash attention, if using pytorch 2.0

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# main class

class Attend(nn.Module):
def __init__(
self,
causal = False,
use_flash_attn = False
):
super().__init__()
self.causal = causal
self.register_buffer("mask", None, persistent=False)

self.use_flash_attn = use_flash_attn
assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

# determine efficient attention configs for cuda and cpu

self.cpu_config = Config(True, True, True)
self.cuda_config = None

if not torch.cuda.is_available() or not use_flash_attn:
return

device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = Config(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = Config(False, True, True)

def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]

mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask

def flash_attn(self, q, k, v, mask = None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

if k.ndim == 3:
k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

if v.ndim == 3:
v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L

masks = []

if self.causal:
i, j = q_len, k_len
causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
masks.append(~causal_mask)

if exists(mask):
if mask.ndim != 2:
mask = repeat(mask, 'w ... -> (b w) ...', b = q.shape[0] // mask.shape[0])

masks.append(mask)

attn_mask = and_reduce(masks)

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = attn_mask
)

return out

def forward(self, q, k, v, mask = None, use_flash_attn = None):
use_flash_attn = default(use_flash_attn, self.use_flash_attn)

b, n, device = q.shape[0], q.shape[-2], q.device

q, ps = pack_one(q, '* h n d')
k, _ = pack_one(k, '* n d')
v, _ = pack_one(v, '* n d')

if use_flash_attn:
out = self.flash_attn(q, k, v, mask = mask)
return unpack_one(out, ps, '* h n d')

scale = q.shape[-1] ** -0.5

k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'
v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'

# similarity

sim = einsum(f"b h i d, {k_einsum} -> b h i j", q, k) * scale

# key padding mask

if exists(mask):
if mask.ndim != 2:
mask = repeat(mask, 'w ... -> (b w) ...', b = b)

sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

# causal mask

if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

# attention

attn = sim.softmax(dim=-1)

# aggregate values

out = einsum(f"b h i j, {v_einsum} -> b h i d", attn, v)

return unpack_one(out, ps, '* h n d')
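
# a minimal usage sketch for the Attend module above; the batch size, head count
# and sequence lengths here are illustrative assumptions, not values from the repo
#
#   attend = Attend(causal = True, use_flash_attn = False)
#   q = torch.randn(2, 8, 1024, 64)   # (batch, heads, seq_len, dim_head)
#   k = torch.randn(2, 2048, 64)      # single-head key / values, longer than the queries
#   v = torch.randn(2, 2048, 64)
#   out = attend(q, k, v)             # (2, 8, 1024, 64)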

# geglu feedforward

class GEGLU(nn.Module):
@@ -155,15 +326,15 @@ def __init__(
causal = True,
qk_rmsnorm = False,
qk_rmsnorm_scale = 8,
single_head_kv = True
use_flash_attn = False
):
super().__init__()
self.causal = causal

self.qk_rmsnorm = qk_rmsnorm
self.qk_rmsnorm_scale = qk_rmsnorm_scale

self.single_head_kv = single_head_kv
self.attend = Attend(causal = causal, use_flash_attn = use_flash_attn)

if qk_rmsnorm:
self.q_scale = nn.Parameter(torch.ones(dim_head))
@@ -193,26 +364,9 @@ def forward(
q = apply_rotary_pos_emb(q, rotary_pos_emb, xpos_scale)
k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale)

# similarity

kv_einsum = '... j d' if self.single_head_kv else '... h j d'

sim = einsum(f'... h i d, {kv_einsum} -> ... h i j', q, k) * scale
# attention


max_neg_value = -torch.finfo(sim.dtype).max

if exists(mask):
sim = sim.masked_fill(~mask, max_neg_value)

if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device = q.device, dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)

attn = sim.softmax(dim = -1)

out = einsum(f'... h i j, {kv_einsum} -> ... h i d', attn, v)
out = self.attend(q, k, v, mask = mask)

return out

@@ -227,7 +381,7 @@ def __init__(
qk_rmsnorm = False,
qk_rmsnorm_scale = 8,
num_state_vectors = 0,
single_head_kv = True
use_flash_attn = False
):
super().__init__()
inner_dim = dim_head * heads
@@ -237,10 +391,9 @@ def __init__(

self.to_q = nn.Linear(dim, inner_dim, bias = False)

self.single_head_kv = single_head_kv
self.to_kv = nn.Linear(dim, (inner_dim if not single_head_kv else dim_head) * 2, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

self.attn = Attention(dim_head, causal = causal, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, single_head_kv = single_head_kv)
self.attn = Attention(dim_head, causal = causal, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

self.block_width = block_width
self.is_recurrent_layer = num_state_vectors > 0
@@ -254,17 +407,17 @@ def __init__(
self.q_from_state = nn.Linear(dim, inner_dim, bias = False)

self.state_to_q = nn.Linear(dim, inner_dim, bias = False)
self.state_to_kv = nn.Linear(dim, (inner_dim if not single_head_kv else dim_head) * 2, bias = False)
self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)

self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))

self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)

self.to_state_cross_attn = Attention(dim_head, causal = False, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, single_head_kv = single_head_kv)
self.to_state_cross_attn = Attention(dim_head, causal = False, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

self.state_self_attn = Attention(dim_head, causal = False, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, single_head_kv = single_head_kv)
self.from_state_cross_attn = Attention(dim_head, causal = False, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, single_head_kv = single_head_kv)
self.state_self_attn = Attention(dim_head, causal = False, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
self.from_state_cross_attn = Attention(dim_head, causal = False, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

# gating related parameters - using the fixed simple config

@@ -301,9 +454,6 @@ def forward(
split_head = partial(rearrange, pattern = 'b n (h d) -> b h n d', h = self.heads)
q = split_head(q)

if not self.single_head_kv:
k, v = map(split_head, (k, v))

# bucket the queries, keys, values by block width

bq, bk, bv = map(lambda t: rearrange(t, 'b ... (w n) d -> b w ... n d', n = width), (q, k, v))
@@ -384,9 +534,6 @@ def forward(
state_q, state_k, state_v = (self.state_to_q(self.init_state), *self.state_to_kv(self.init_state).chunk(2, dim = -1))
state_q = repeat(state_q, 'n (h d) -> b h n d', h = self.heads, b = batch)

if not self.single_head_kv:
state_k, state_v = map(lambda t: repeat(t, 'n (h d) -> b h n d', h = self.heads, b = batch), (state_k, state_v))

# cross attend to the past states key values

to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)
@@ -434,7 +581,6 @@ def __init__(
depth,
dim_head = 64,
heads = 8,
single_head_kv = True,
all_layers_qk_rmsnorm = False,
ff_mult = 4,
max_seq_len = 1024,
@@ -444,7 +590,8 @@ def __init__(
num_state_vectors = None,
enhanced_recurrence = False,
ignore_index = -100,
rotary_use_xpos = True
rotary_use_xpos = True,
use_flash_attn = False
):
super().__init__()
num_state_vectors = default(num_state_vectors, block_width)
@@ -484,9 +631,9 @@ def __init__(
block_width = block_width,
dim_head = dim_head,
heads = heads,
single_head_kv = single_head_kv,
qk_rmsnorm = qk_rmsnorm,
num_state_vectors = layer_num_state_vectors
num_state_vectors = layer_num_state_vectors,
use_flash_attn = use_flash_attn
),
FeedForward(dim, mult = ff_mult)
]))
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'block-recurrent-transformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.19',
version = '0.1.0',
license='MIT',
description = 'Block Recurrent Transformer - Pytorch',
author = 'Phil Wang',