from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn import functional as F
from .functional.attention import kv_attention


class KeyValAttention(nn.Module):
    """ Key-value attention.

    :param scaling: whether to scale the attention scores (scaled dot-product
        attention)
    :param dropout_prob: dropout probability applied inside the attention
    """

    def __init__(self,
                 scaling: bool = False,
                 dropout_prob: float = 0):
        super(KeyValAttention, self).__init__()
        self._scaling = scaling
        self._dropout = dropout_prob

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                additive_mask: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Compute key-value attention.

        See `functional.attention.kv_attention` for details.

        :param query: query tensor
        :param key: key tensor
        :param value: value tensor
        :param mask: optional attention mask
        :param additive_mask: optional additive attention mask
        :return: a tuple of the attention output and the attention weights
        """
        return kv_attention(query, key, value, mask, additive_mask,
                            self.training, self._dropout, self._scaling)

class AttentionPool2d(nn.Module):
    """ Attention pooling over a 2D feature map, adapted from OpenAI CLIP.

    :param embed_dim: channel dimension of the input feature map
    :param num_heads: number of attention heads
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int):
        super().__init__()
        # Separate projection weights for query, key and value, plus a shared
        # in-projection bias of size 3 * embed_dim.
        self.k_proj = nn.Parameter(torch.randn(embed_dim, embed_dim))
        self.q_proj = nn.Parameter(torch.randn(embed_dim, embed_dim))
        self.v_proj = nn.Parameter(torch.randn(embed_dim, embed_dim))
        self.bias = nn.Parameter(torch.randn(3 * embed_dim))
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.initialize_weights()

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        x = x.flatten(-2).permute(2, 0, 1)  # NCHW -> (HW)NC
        # Prepend the mean over all positions; it serves as the pooling query.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (1+HW)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj,
            k_proj_weight=self.k_proj,
            v_proj_weight=self.v_proj,
            in_proj_weight=None,
            in_proj_bias=self.bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        # Return the attended mean token: one pooled vector per example (NC).
        return x[0]

    def initialize_weights(self):
        # Scaled normal initialization for all projections; zero the shared bias.
        std = self.c_proj.in_features ** -0.5
        nn.init.normal_(self.k_proj, std=std)
        nn.init.normal_(self.q_proj, std=std)
        nn.init.normal_(self.v_proj, std=std)
        nn.init.normal_(self.c_proj.weight, std=std)
        nn.init.zeros_(self.bias)
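

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): pool a CLIP-style 7x7 feature map with 2048 channels down to
    # one 2048-dim vector per image. The shapes here are example assumptions.
    pool = AttentionPool2d(embed_dim=2048, num_heads=8)
    features = torch.randn(4, 2048, 7, 7)  # NCHW feature map
    pooled = pool(features)
    print(pooled.shape)  # expected: torch.Size([4, 2048])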