# -*- coding: utf-8 -*-
"""
Linear attention in Based.
https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
"""
import torch
import torch.nn as nn
from einops import rearrange
from fla.modules.feature_map import TaylorFeatureMap
from fla.ops.based import parallel_based
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
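
# Background note (added for exposition; not in the upstream file):
# Based approximates softmax attention by replacing exp(q . k) with its
# 2nd-order Taylor expansion,
#     exp(q . k) ~= 1 + q . k + (q . k)^2 / 2,
# and TaylorFeatureMap computes a feature map phi such that
# phi(q) . phi(k) matches that expansion. With such a feature map, causal
# attention reduces to cumulative sums over the sequence dimension, as in
# `forward_reference` below.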


class BasedLinearAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        feature_dim: int = 16,
        num_key_value_heads: int = 12,
        num_heads: int = 12,
        feature_name: str = "taylor_exp",
        eps: float = 1e-12,
        causal: bool = True,
        mode: str = "parallel",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.mode = mode
        self.feature_name = feature_name
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.causal = causal

        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()
        self.feature_map = TaylorFeatureMap(feature_dim)
        self.eps = eps

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
        if mode == "fused_chunk":
            q, k = self.feature_map(q), self.feature_map(k)
            o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'chunk':
            q, k = self.feature_map(q), self.feature_map(k)
            o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'parallel':
            # the fused kernel applies the Taylor feature map internally
            assert q.shape[-1] <= 128
            o = parallel_based(q, k, v, True, True)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")
        o = rearrange(o, "b h l d -> b l (h d)")
        o = self.o_proj(o)
        o = self.dropout(o)
        return o
# https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
"""
x (torch.Tensor): tensor of shape (b, d, l)
y (torch.Tensor): tensor of shape (b, d, l)
"""
# hidden_states = hidden_states.transpose(1, 2)
b, l, _ = hidden_states.size()
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Linear attention
q, k = self.feature_map(q), self.feature_map(k)
q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
# Compute attention
if self.causal:
y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
else:
y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
y = rearrange(y, 'b h l d -> b l (h d)')
y = self.o_proj(y.to(hidden_states.dtype))
y = self.dropout(y)
return y.to(hidden_states.dtype)
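

# Minimal smoke test (added; not part of the upstream file). It exercises the
# pure-PyTorch `forward_reference` path, which needs no Triton kernels; the
# sizes below are illustrative and assume num_heads == num_key_value_heads so
# that hidden_size == num_heads * head_dim.
if __name__ == "__main__":
    batch_size, seq_len, hidden_size = 2, 64, 768
    attn = BasedLinearAttention(hidden_size, feature_dim=16, num_heads=12, num_key_value_heads=12)
    x = torch.randn(batch_size, seq_len, hidden_size)
    y = attn.forward_reference(x)
    assert y.shape == (batch_size, seq_len, hidden_size)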