sub-quadratic attention #1

Open
wants to merge 35 commits into base: main

Changes from 1 commit

35 commits
c810c32
initial commit of sub-quadratic attention source from https://github.…
Birch-san Dec 26, 2022
c9b3b9f
invoke efficient_dot_product_attention(). not currently giving correc…
Birch-san Dec 26, 2022
70dc50d
provide a way to skip checkpointing
Birch-san Dec 26, 2022
c794f0b
MPS fixes; now working
Birch-san Dec 26, 2022
04a5cbe
eliminate all einsums. assume 3D tensor [batch * num_heads, tokens, c…
Birch-san Dec 26, 2022
b44fa12
remove the bits that I broke in the pursuit of speed (mask, bias, wei…
Birch-san Dec 26, 2022
8694703
clarify comment; verified that upcast_attention is indeed still helpf…
Birch-san Dec 26, 2022
5bfe96d
add TODO about softmax
Birch-san Dec 26, 2022
da8901b
typings
Birch-san Dec 26, 2022
0c4d82f
simplify protocols
Birch-san Dec 26, 2022
c5e8e31
remove unused
Birch-san Dec 26, 2022
b16edc9
simplify protocol
Birch-san Dec 26, 2022
b7fc3a8
fix tensor shape destructuring
Birch-san Dec 26, 2022
8f003c2
simplify dynamic_slice
Birch-san Dec 26, 2022
1334670
simplify chunk scanning
Birch-san Dec 26, 2022
0676c13
inline sole use of map_pt function
Birch-san Dec 26, 2022
264dfb7
simplify
Birch-san Dec 26, 2022
205f55b
no longer using original utilities from memory-efficient-attention re…
Birch-san Dec 26, 2022
1880c0e
fix query slicing
Birch-san Dec 26, 2022
8603c30
fix kv chunking
Birch-san Dec 26, 2022
96e0d8c
simplify dynamic slicing
Birch-san Dec 26, 2022
63ca66d
removed bias, mask, weights, calc_fn, and the conditions controlling …
Birch-san Dec 26, 2022
f4c0bf4
device arg fix no longer included
Birch-san Dec 26, 2022
624123f
simplify
Birch-san Dec 26, 2022
5b92dab
clarify attributions now that algorithm has been substantially rewritten
Birch-san Dec 26, 2022
60f0a5e
add chunk_threshold_bytes to let you specify your safe memory limit, …
Birch-san Dec 28, 2022
48db711
fast path for when we're just attention-slicing (i.e. chunking query …
Birch-san Dec 28, 2022
ef20fb9
default kv_chunk_size was meant to be sqrt() of global key size, not …
Birch-san Dec 28, 2022
69a8d2e
remove debug notes
Birch-san Dec 28, 2022
db25934
explain kv fast-path
Birch-san Dec 28, 2022
7aa8bac
add fast-path for "1 query chunk"
Birch-san Dec 28, 2022
59002c3
move kv_chunk_size_min concern to callsite, since if caller knows fin…
Birch-san Dec 28, 2022
a3152d8
Revert "move kv_chunk_size_min concern to callsite (1c4f10748e31d1851…
Birch-san Dec 28, 2022
0eafb95
de-duplicate fast-path for "matmul < quota". we can just ask for ever…
Birch-san Dec 28, 2022
9dc6822
pre-transpose key, rather than transposing it then undoing the transp…
Birch-san Dec 30, 2022
move kv_chunk_size_min concern to callsite, since if caller knows final kv_chunk_size: they can notice when no chunking would happen at all, and use fast-path. note: there's a question of whether that concern belongs *inside* the algorithm. but it'd feel weird for chunked attention to have a no-chunking-at-all branch.
Birch-san committed Dec 30, 2022
commit 59002c33af66d561f9c844ed4ae047b1e1d1910f
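To make the commit's reasoning concrete: once the caller resolves the final kv_chunk_size itself, it can tell up front whether any chunking would happen and otherwise take a fast path. A minimal sketch with purely hypothetical numbers (not part of the diff below):

import math

# Hypothetical sizes for illustration: 4096 query/key tokens.
q_tokens, k_tokens = 4096, 4096
query_chunk_size = 1024
kv_chunk_size_min = None  # optional floor on the key/value chunk size

# sqrt-of-key-tokens default, capped at k_tokens (mirroring the diff below).
kv_chunk_size = min(int(math.sqrt(k_tokens)), k_tokens)  # -> 64
if kv_chunk_size_min is not None:
    kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

# Knowing the final kv_chunk_size, the caller can see whether chunking
# would actually occur, and otherwise fall back to ordinary attention.
uses_chunking = q_tokens > query_chunk_size or k_tokens > kv_chunk_size
print(kv_chunk_size, uses_chunking)  # 64 True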
12 changes: 9 additions & 3 deletions src/diffusers/models/cross_attention.py

@@ -17,6 +17,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
+import math
 
 from ..utils.import_utils import is_xformers_available
 
@@ -318,14 +319,19 @@ def __call__(
         _, k_tokens, _ = key.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes:
+        kv_chunk_size = min(self.kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+        if self.kv_chunk_size_min is not None:
+            kv_chunk_size = max(kv_chunk_size, self.kv_chunk_size_min)
+
+        uses_chunking = q_tokens > self.query_chunk_size or k_tokens > kv_chunk_size
+
+        if uses_chunking and (self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes):
             hidden_states = efficient_dot_product_attention(
                 query,
                 key,
                 value,
                 query_chunk_size=self.query_chunk_size,
-                kv_chunk_size=self.kv_chunk_size,
-                kv_chunk_size_min=self.kv_chunk_size_min,
+                kv_chunk_size=kv_chunk_size,
                 use_checkpoint=attn.training,
             )
         else:
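The chunk_threshold_bytes condition in the hunk above compares the size the full attention-score matmul would occupy against a user-supplied quota. A rough worked example, assuming 4-byte (fp32) score elements and purely illustrative shapes:

# Hypothetical: 2 images x 8 heads, 4096 query and key tokens.
batch_x_heads = 16
bytes_per_token = 4          # assumption: 4-byte fp32 attention scores
q_tokens = k_tokens = 4096

qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
print(qk_matmul_size_bytes)  # 1073741824 bytes, i.e. 1 GiB

# With e.g. chunk_threshold_bytes = 512 MiB, the quota is exceeded,
# so the chunked (sub-quadratic) path would be taken.
chunk_threshold_bytes = 512 * 1024 ** 2
print(qk_matmul_size_bytes > chunk_threshold_bytes)  # True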
6 changes: 0 additions & 6 deletions src/diffusers/models/sub_quadratic_attention.py

@@ -125,7 +125,6 @@ def efficient_dot_product_attention(
     value: Tensor,
     query_chunk_size=1024,
     kv_chunk_size: Optional[int] = None,
-    kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
 ):
     """Computes efficient dot-product attention given query, key, and value.
@@ -140,7 +139,6 @@
         `[batch * num_heads, tokens, channels_per_head]`.
       query_chunk_size: int: query chunks size
       kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
-      kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
       use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
     Returns:
       Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
@@ -149,10 +147,6 @@
     _, k_tokens, _ = key.shape
     scale = q_channels_per_head ** -0.5
 
-    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
-    if kv_chunk_size_min is not None:
-        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
-
     def get_query_chunk(chunk_idx: int) -> Tensor:
         return dynamic_slice(
             query,
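After this commit, efficient_dot_product_attention no longer accepts kv_chunk_size_min; the caller passes a fully resolved kv_chunk_size. A minimal usage sketch, with hypothetical shapes and assuming the module is importable from this branch as shown:

import math
import torch
from diffusers.models.sub_quadratic_attention import efficient_dot_product_attention

# Hypothetical shapes: batch * num_heads = 16, 4096 tokens, 40 channels per head.
batch_x_heads, q_tokens, k_tokens, channels_per_head = 16, 4096, 4096, 40
query = torch.randn(batch_x_heads, q_tokens, channels_per_head)
key = torch.randn(batch_x_heads, k_tokens, channels_per_head)
value = torch.randn(batch_x_heads, k_tokens, channels_per_head)

# The caller now resolves the final key/value chunk size itself
# (sqrt-of-key-tokens default), per the commit above.
kv_chunk_size = min(int(math.sqrt(k_tokens)), k_tokens)

hidden_states = efficient_dot_product_attention(
    query,
    key,
    value,
    query_chunk_size=1024,
    kv_chunk_size=kv_chunk_size,
    use_checkpoint=False,  # docstring recommends False for inference
)
assert hidden_states.shape == (batch_x_heads, q_tokens, channels_per_head)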