# -*- coding: utf-8 -*-
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from fla.modules import RMSNorm
from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
HedgehogFeatureMap, T2RFeatureMap)
from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
fused_recurrent_linear_attn)


class LinearAttention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
num_heads: int = 8,
num_kv_heads: Optional[int] = None,
feature_map: str = 'elementwise_product',
tie_feature_map_qk: bool = False,
output_norm: str = 'rmsnorm',
norm_q: bool = False,
norm_k: bool = False,
# standard linear attention normalization
do_feature_map_norm: bool = False,
elementwise_affine: bool = True,
norm_eps: float = 1e-5,
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
self.mode = mode
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads ({num_heads})"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads ({num_heads})"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.do_feature_map_norm = do_feature_map_norm
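        # select the feature map phi applied to queries/keys; linear attention replaces
        # softmax attention with phi(q) @ (phi(k)^T v), so phi determines the kernel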
if feature_map == 'hedgehog':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 't2r':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'elementwise_product':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'dpfp':
self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'elu':
def elu(x):
return F.elu(x) + 1
self.feature_map_q = elu
self.feature_map_k = elu
elif feature_map == 'relu':
self.feature_map_q = nn.ReLU()
self.feature_map_k = nn.ReLU()
elif feature_map == 'identity':
self.feature_map_q = nn.Identity()
self.feature_map_k = nn.Identity()
else:
raise NotImplementedError(f"Not supported feature map `{feature_map}`.")
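        # q/k/v projections; k/v are projected per KV group and expanded to all heads in forward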
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
if output_norm == 'rmsnorm':
self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
elif output_norm == 'identity':
self.norm = nn.Identity()
else:
raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
self.norm_q = norm_q
self.norm_k = norm_k
self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True

    def forward(self, x):
mode = self.mode
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = self.feature_map_q(q)
k = self.feature_map_k(k)
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
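        # with grouped KV heads (num_kv_heads < num_heads), repeat k/v so every query head
        # has a matching key/value head; otherwise just split the head dimension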
if self.num_kv_groups > 1:
k, v = (repeat(x, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
else:
k, v = (rearrange(x, 'b n (h d) -> b h n d', h=self.num_kv_heads) for x in (k, v))
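        # optionally normalize the feature-mapped q/k by their sum over the feature dimension
        # (the small epsilon guards against division by zero)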
if self.norm_q:
q = q / (q.sum(-1, True) + 1e-4)
if self.norm_k:
k = k / (k.sum(-1, True) + 1e-4)
if mode == 'chunk':
o, final_state = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_chunk':
o, final_state = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_recurrent':
o, final_state = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
else:
raise NotImplementedError
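        # normalize each head's output, merge heads, and project back to hidden_size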
o = self.norm(o)
o = rearrange(o, 'b h n d -> b n (h d)')
o = self.o_proj(o)
return o
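

# Minimal usage sketch, assuming the `fla` package is installed and a CUDA device is
# available (the fused kernels are Triton-based); the batch/sequence/hidden sizes below
# are illustrative only.
if __name__ == "__main__":
    import torch

    attn = LinearAttention(mode='fused_recurrent', hidden_size=1024, num_heads=8).cuda()
    x = torch.randn(2, 64, 1024, device='cuda')  # (batch, seq_len, hidden_size)
    y = attn(x)
    print(y.shape)  # torch.Size([2, 64, 1024])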