forked from mistralai/mistral-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cache.py
213 lines (175 loc) · 8.02 KB
/
cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import torch
from typing import List, Tuple
from dataclasses import dataclass
from xformers.ops.fmha.attn_bias import (
AttentionBias,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
)
@dataclass
class RotatingCacheInputMetadata:
    """
    Bookkeeping for one forward pass over a rotating KV cache.

    Instances are built by ``RotatingBufferCache.get_input_metadata`` and consumed
    by ``CacheView`` (to scatter new keys/values into the cache and to pick the
    attention mask).
    """

    # absolute token positions of every new token (concatenated over the batch), used for rope
    positions: torch.Tensor
    # boolean mask over the new tokens: True for those that must be written into the cache
    to_cache_mask: torch.Tensor
    # per sequence, how many of its new tokens get cached
    cached_elements: torch.Tensor
    # flat cache slot (batch_row * sliding_window + position % sliding_window) for each cached token
    cache_positions: torch.Tensor
    # True for a prompt-processing pass (first chunk, or any chunk with more than one token per sequence)
    prefill: bool
    # attention bias: block-diagonal causal (first prefill), block-diagonal with bottom-right
    # local window (subsequent prefill), or causal with padded keys (single-token decode)
    mask: AttentionBias
    # number of new tokens per sequence in this pass
    seqlens: List[int]
def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]):
    """
    Merge two equal-length lists by alternating their items.

    Returns ``[l1[0], l2[0], l1[1], l2[1], ...]`` as a new list.
    Raises AssertionError when the lengths differ.
    """
    assert len(l1) == len(l2)
    merged = []
    for left, right in zip(l1, l2):
        merged.extend((left, right))
    return merged
def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
    """
    Return the entries of one rotating buffer in chronological order.

    ``cache`` is a 3-d buffer whose first axis is the window (W, H, D), and
    ``seqlen`` is how many tokens have been written into it so far. Token at
    absolute position p lives in slot ``p % W``.

    - Fewer than W tokens: the buffer has not wrapped, so the first ``seqlen``
      slots are returned as-is.
    - Exactly a multiple of W: slot 0 holds the oldest token, so the buffer is
      already in order.
    - Otherwise: the oldest surviving token sits at slot ``seqlen % W``; the
      buffer is rotated so that slot comes first.
    """
    assert cache.ndim == 3  # (W, H, D)
    window = cache.shape[0]
    oldest = seqlen % window
    if seqlen < window:
        return cache[:seqlen]
    if oldest == 0:
        return cache
    tail, head = cache[oldest:], cache[:oldest]
    return torch.cat([tail, head], dim=0)
class CacheView:
    """
    One layer's slice of a rotating KV cache, bundled with the metadata of the
    current forward pass.

    ``cache_k`` / ``cache_v`` are indexed as (batch_row, window_slot, head, dim);
    ``RotatingBufferCache.get_view`` passes ``cache[layer_id]`` of shape
    (max_batch_size, sliding_window, n_kv_heads, head_dim) — assumed the same
    for any other caller. ``kv_seqlens`` is the per-row running token count
    tensor owned by the parent cache.
    """

    def __init__(self, cache_k: torch.Tensor, cache_v: torch.Tensor, metadata: RotatingCacheInputMetadata, kv_seqlens: torch.Tensor):
        self.metadata = metadata
        self.kv_seqlens = kv_seqlens
        self.cache_k, self.cache_v = cache_k, cache_v

    def update(self, xk: torch.Tensor, xv: torch.Tensor):
        """
        Write the new keys/values selected by ``metadata.to_cache_mask`` (the last
        [sliding_window] tokens of each sequence) into their cache slots.

        The cache is viewed as a flat list of (n_kv_heads, head_dim) rows so that
        ``metadata.cache_positions`` (batch_row * window + slot) can index it directly.
        """
        n_kv_heads, head_dim = self.cache_k.shape[-2:]
        keep = self.metadata.to_cache_mask
        slots = self.metadata.cache_positions
        for buffer, fresh in ((self.cache_k, xk), (self.cache_v, xv)):
            buffer.view(-1, n_kv_heads, head_dim).index_copy_(0, slots, fresh[keep])

    def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build full key/value sequences as [cached_0, new_0, cached_1, new_1, ...].

        ``xk``/``xv`` are the new tokens of the whole batch packed along dim 0
        ((B * T, H, D)); each row's cached entries are put back in chronological
        order with ``unrotate`` before being prepended to that row's new tokens.
        Reads the cache as it is at call time (presumably before ``update`` for
        this step — confirm against the caller).

        This is a naive implementation and not optimized for speed.
        """
        assert xk.ndim == xv.ndim == 3  # (B * T, H, D)
        assert xk.shape == xv.shape

        new_lens = self.metadata.seqlens
        # NOTE(review): the early exit tests the *new* token counts, not kv_seqlens,
        # despite being described as "no cache to interleave".
        if not any(s != 0 for s in new_lens):
            return xk, xv

        # per-sequence chunks of shape (T_i, H, D)
        k_chunks = torch.split(xk, new_lens)
        v_chunks = torch.split(xv, new_lens)
        assert len(k_chunks) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(k_chunks)}"

        ordered_k = [unrotate(buf, n) for buf, n in zip(self.cache_k, self.kv_seqlens)]
        ordered_v = [unrotate(buf, n) for buf, n in zip(self.cache_v, self.kv_seqlens)]
        full_k = torch.cat(interleave_list(ordered_k, k_chunks), dim=0)
        full_v = torch.cat(interleave_list(ordered_v, v_chunks), dim=0)
        return full_k, full_v

    @property
    def sliding_window(self):
        """Window length, read from axis 1 of the key buffer."""
        return self.cache_k.shape[1]

    @property
    def key(self) -> torch.Tensor:
        """Key buffer restricted to the batch rows currently in use."""
        active = len(self.kv_seqlens)
        return self.cache_k[:active]

    @property
    def value(self) -> torch.Tensor:
        """Value buffer restricted to the batch rows currently in use."""
        active = len(self.kv_seqlens)
        return self.cache_v[:active]

    @property
    def prefill(self):
        """Whether this pass was flagged as prefill in the metadata."""
        return self.metadata.prefill

    @property
    def mask(self):
        """Attention bias chosen for this pass."""
        return self.metadata.mask
class RotatingBufferCache:
    """
    This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences.
    Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms)

    Layout: cache_k / cache_v are (n_layers, max_batch_size, sliding_window, n_kv_heads, head_dim).
    The token at absolute position p of batch row b is written to slot p % sliding_window of row b,
    so each row keeps at most the last `sliding_window` tokens.
    """

    def __init__(self, n_layers: int, max_batch_size: int, sliding_window: int, n_kv_heads: int, head_dim: int):
        self.sliding_window = sliding_window
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        shape = (n_layers, max_batch_size, sliding_window, n_kv_heads, head_dim)
        # torch.empty: contents start uninitialized; slots are filled through CacheView.update
        self.cache_k = torch.empty(shape)
        self.cache_v = torch.empty(shape)
        # Running per-row token count (not clamped to the window; clamping happens
        # where masks are built). None until the first get_input_metadata call or after reset().
        self.kv_seqlens = None

    def get_view(self, layer_id: int, metadata: RotatingCacheInputMetadata) -> CacheView:
        """Return a CacheView over layer `layer_id`'s buffers for the current pass."""
        return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens)

    def reset(self):
        """Forget all per-row lengths; the buffers themselves are not cleared."""
        self.kv_seqlens = None

    def init_kvseqlens(self, batch_size: int):
        """Start a new batch of `batch_size` rows, all with zero cached tokens."""
        self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long)

    @property
    def device(self):
        """Device of the KV buffers."""
        return self.cache_k.device

    def to(self, device: torch.device, dtype: torch.dtype):
        """
        Move the KV buffers to `device` / `dtype` and return self for chaining.

        If the per-row length counter is already initialized it follows the device
        change too (keeping its integer dtype); otherwise update_seqlens would mix
        devices on its in-place add.
        """
        self.cache_k = self.cache_k.to(device=device, dtype=dtype)
        self.cache_v = self.cache_v.to(device=device, dtype=dtype)
        if self.kv_seqlens is not None:
            self.kv_seqlens = self.kv_seqlens.to(device=device)
        return self

    def update_seqlens(self, seqlens: List[int]):
        """Advance each row's running token count by the matching entry of `seqlens`."""
        self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)

    def get_input_metadata(self, seqlens: List[int]) -> RotatingCacheInputMetadata:
        """
        Compute positions, cache slots and the attention mask for a pass that adds
        `seqlens[i]` new tokens to batch row i.

        input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
            --> only cache last 3 tokens in each sequence
            - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
            - cached_elements = [3 | 3 | 2]
            --> absolute positions are used for rope
            - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
            --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
            - cache_positions = [2 0 1 | 5 3 4 | 6 7]

        The mask is chosen as follows:
          - first prefill (all rows start at position 0): block-diagonal causal mask
            restricted to the sliding window;
          - subsequent prefill (some row adds > 1 token): block-diagonal mask over
            (cached + new) keys with a bottom-right-aligned local window;
          - otherwise (decode): causal mask with keys padded to `sliding_window`.

        Raises AssertionError on an empty batch or when the batch size differs from
        the one the cache was initialized with (call reset() between batches).
        """
        # Validate before initializing, so a bad call does not leave a size-0 counter behind.
        assert len(seqlens) > 0, seqlens
        if self.kv_seqlens is None:
            self.init_kvseqlens(len(seqlens))
        assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
        seqpos = self.kv_seqlens.tolist()

        # Per sequence: True for the last `sliding_window` tokens (the only ones worth caching).
        masks = [
            [x >= seqlen - self.sliding_window for x in range(seqlen)]
            for seqlen in seqlens
        ]
        # Flatten with comprehensions (linear) instead of sum(lists, []) (quadratic in batch size).
        to_cache_mask = torch.tensor([flag for mask in masks for flag in mask], device=self.device, dtype=torch.bool)
        cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long)
        batch_idx = torch.tensor([i for i, seqlen in enumerate(seqlens) for _ in range(seqlen)], device=self.device, dtype=torch.long)
        cache_positions = positions % self.sliding_window + batch_idx * self.sliding_window

        first_prefill = seqpos[0] == 0
        subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
        if first_prefill:
            assert all(pos == 0 for pos in seqpos), seqpos
            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.sliding_window)
        elif subsequent_prefill:
            mask = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens,
                # keys = new tokens plus at most a window's worth of cached ones
                kv_seqlen=[s + cached_s.clamp(max=self.sliding_window).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
            ).make_local_attention_from_bottomright(self.sliding_window)
        else:
            mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                q_seqlen=seqlens,
                kv_padding=self.sliding_window,
                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.sliding_window).tolist()
            )

        return RotatingCacheInputMetadata(
            positions=positions,
            to_cache_mask=to_cache_mask,
            cached_elements=cached_elements,
            cache_positions=cache_positions[to_cache_mask],
            prefill=first_prefill or subsequent_prefill,
            mask=mask,
            seqlens=seqlens,
        )