Make the self-attention operator replaceable in Transformer #334

Merged · 3 commits · Apr 2, 2024
Changes from all commits
24 changes: 20 additions & 4 deletions pypots/imputation/crossformer/modules/submodules.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from einops import rearrange, repeat
 
-from ....nn.modules.transformer import MultiHeadAttention
+from ....nn.modules.transformer import ScaledDotProductAttention, MultiHeadAttention
 
 
 class TwoStageAttentionLayer(nn.Module):
@@ -33,10 +33,26 @@ def __init__(
         super().__init__()
         d_ff = 4 * d_model if d_ff is None else d_ff
         self.time_attention = MultiHeadAttention(
-            n_heads, d_model, d_k, d_v, attn_dropout
+            n_heads,
+            d_model,
+            d_k,
+            d_v,
+            ScaledDotProductAttention(d_k**0.5, attn_dropout),
         )
-        self.dim_sender = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
-        self.dim_receiver = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.dim_sender = MultiHeadAttention(
+            n_heads,
+            d_model,
+            d_k,
+            d_v,
+            ScaledDotProductAttention(d_k**0.5, attn_dropout),
+        )
+        self.dim_receiver = MultiHeadAttention(
+            n_heads,
+            d_model,
+            d_k,
+            d_v,
+            ScaledDotProductAttention(d_k**0.5, attn_dropout),
+        )
         self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
 
         self.dropout = nn.Dropout(dropout)
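For reference, a minimal sketch of the calling convention this file now follows: MultiHeadAttention receives an attention operator instance instead of a dropout rate, so the scaled dot-product operator is built by the caller and injected. The dimensions and tensor shapes below are illustrative assumptions, not values taken from this diff.

import torch

from pypots.nn.modules.transformer import ScaledDotProductAttention, MultiHeadAttention

# illustrative sizes only
n_heads, d_model, d_k, d_v = 4, 64, 16, 16

# the operator carries the temperature and the attention-map dropout
attn_opt = ScaledDotProductAttention(d_k**0.5, 0.1)
mha = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_opt)

x = torch.randn(8, 24, d_model)  # assumed [batch_size, n_steps, d_model]
output, attn_weights = mha(x, x, x, attn_mask=None)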
3 changes: 2 additions & 1 deletion pypots/imputation/patchtst/modules/core.py
@@ -8,6 +8,7 @@
 import torch.nn as nn
 
 from .submodules import PatchEmbedding, FlattenHead
+from ....nn.modules.transformer.attention import ScaledDotProductAttention
 from ....nn.modules.transformer.auto_encoder import EncoderLayer
 from ....utils.metrics import calc_mse
 
@@ -49,8 +50,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
5 changes: 3 additions & 2 deletions pypots/imputation/saits/modules/core.py
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
+from ....nn.modules.transformer.attention import ScaledDotProductAttention
 from ....utils.metrics import calc_mae
 
 
@@ -59,8 +60,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
@@ -73,8 +74,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
3 changes: 2 additions & 1 deletion pypots/imputation/transformer/modules/core.py
@@ -19,6 +19,7 @@
 import torch.nn as nn
 
 from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
+from ....nn.modules.transformer.attention import ScaledDotProductAttention
 from ....utils.metrics import calc_mae
 
 
@@ -52,8 +53,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
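The PatchTST, SAITS, and Transformer cores above all repeat the same pattern: each EncoderLayer is now handed its own ScaledDotProductAttention instance, with the layer-level dropout following it. A condensed sketch of that stack construction follows; the leading d_model and d_ffn arguments sit outside the visible hunks and are assumed here, and the hyperparameter values are illustrative only.

import torch.nn as nn

from pypots.nn.modules.transformer import EncoderLayer
from pypots.nn.modules.transformer.attention import ScaledDotProductAttention

# illustrative hyperparameters
n_layers, d_model, d_ffn, n_heads, d_k, d_v = 2, 64, 128, 4, 16, 16
dropout, attn_dropout = 0.1, 0.1

# one attention operator per encoder layer, mirroring the list comprehensions above
layer_stack = nn.ModuleList(
    [
        EncoderLayer(
            d_model,
            d_ffn,
            n_heads,
            d_k,
            d_v,
            ScaledDotProductAttention(d_k**0.5, attn_dropout),
            dropout,
        )
        for _ in range(n_layers)
    ]
)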
14 changes: 0 additions & 14 deletions pypots/modules/__init__.py

This file was deleted.

43 changes: 31 additions & 12 deletions pypots/nn/modules/transformer/attention.py
@@ -16,9 +16,30 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from abc import abstractmethod
 
 
-class ScaledDotProductAttention(nn.Module):
+class AttentionOperator(nn.Module):
+    """
+    The abstract class for all attention layers.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+
+
+class ScaledDotProductAttention(AttentionOperator):
     """Scaled dot-product attention.
 
     Parameters
@@ -44,15 +65,18 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward processing of the scaled dot-product attention.
 
         Parameters
         ----------
         q:
             Query tensor.
+
         k:
             Key tensor.
+
         v:
             Value tensor.
 
@@ -106,11 +130,8 @@ class MultiHeadAttention(nn.Module):
     d_v:
         The dimension of the value tensor.
 
-    attn_dropout:
-        The dropout rate for the attention map.
-
-    attn_temperature:
-        The temperature for scaling. Default is None, which means d_k**0.5 will be applied.
+    attention_operator:
+        The attention operator, e.g. the self-attention proposed in Transformer.
 
     """
 
@@ -120,13 +141,10 @@ def __init__(
         d_model: int,
         d_k: int,
         d_v: int,
-        attn_dropout: float,
-        attn_temperature: float = None,
+        attention_operator: AttentionOperator,
     ):
         super().__init__()
 
-        attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature
-
         self.n_heads = n_heads
         self.d_k = d_k
         self.d_v = d_v
@@ -135,7 +153,7 @@ def __init__(
         self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
         self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)
 
-        self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout)
+        self.attention_operator = attention_operator
         self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
 
     def forward(
@@ -144,6 +162,7 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward processing of the multi-head attention module.
 
@@ -189,7 +208,7 @@ def forward(
             # broadcasting on the head axis
             attn_mask = attn_mask.unsqueeze(1)
 
-        v, attn_weights = self.attention(q, k, v, attn_mask)
+        v, attn_weights = self.attention_operator(q, k, v, attn_mask, **kwargs)
 
         # transpose back -> [batch_size, n_steps, n_heads, d_v]
         # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v]
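The new AttentionOperator base class is what makes the operator replaceable: anything implementing forward(q, k, v, attn_mask, **kwargs) and returning (output, attention weights) can be handed to MultiHeadAttention. Below is a hypothetical sketch of a custom operator, not part of this PR; it assumes the same masking convention as ScaledDotProductAttention (zeros in attn_mask mark positions to ignore) and that q/k/v arrive already split per head by the multi-head wrapper.

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from pypots.nn.modules.transformer import MultiHeadAttention
from pypots.nn.modules.transformer.attention import AttentionOperator


class UnscaledDotProductAttention(AttentionOperator):
    """A toy operator: dot-product attention without temperature scaling."""

    def __init__(self, attn_dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(attn_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # raw attention scores, no temperature scaling
        attn = torch.matmul(q, k.transpose(-2, -1))
        if attn_mask is not None:
            # assumed convention: 0 means "mask this position out"
            attn = attn.masked_fill(attn_mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn


# the custom operator drops straight into the multi-head wrapper
mha = MultiHeadAttention(4, 64, 16, 16, UnscaledDotProductAttention(0.1))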
6 changes: 4 additions & 2 deletions pypots/nn/modules/transformer/auto_encoder.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 
+from .attention import ScaledDotProductAttention
 from .embedding import PositionalEncoding
 from .layers import EncoderLayer, DecoderLayer
 
@@ -78,8 +79,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
@@ -190,8 +191,9 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
42 changes: 29 additions & 13 deletions pypots/nn/modules/transformer/layers.py
@@ -11,7 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .attention import MultiHeadAttention
+from .attention import MultiHeadAttention, AttentionOperator
 
 
 class PositionWiseFeedForward(nn.Module):
@@ -85,11 +85,12 @@ class EncoderLayer(nn.Module):
     d_v:
         The dimension of the value tensor.
 
+    slf_attn_opt:
+        The attention operator for the self multi-head attention module in the encoder layer.
+
     dropout:
         The dropout rate.
 
-    attn_dropout:
-        The dropout rate for the attention map.
     """
 
     def __init__(
@@ -99,11 +100,11 @@ def __init__(
         n_heads: int,
         d_k: int,
         d_v: int,
+        slf_attn_opt: AttentionOperator,
         dropout: float = 0.1,
-        attn_dropout: float = 0.1,
     ):
         super().__init__()
-        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt)
         self.dropout = nn.Dropout(dropout)
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
         self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)
@@ -112,6 +113,7 @@ def forward(
         self,
         enc_input: torch.Tensor,
         src_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward processing of the encoder layer.
 
@@ -137,6 +139,7 @@ def forward(
             enc_input,
             enc_input,
             attn_mask=src_mask,
+            **kwargs,
         )
 
         # apply dropout and residual connection
@@ -170,12 +173,15 @@ class DecoderLayer(nn.Module):
     d_v:
         The dimension of the value tensor.
 
+    slf_attn_opt:
+        The attention operator for the self multi-head attention module in the decoder layer.
+
+    enc_attn_opt:
+        The attention operator for the encoding multi-head attention module in the decoder layer.
+
     dropout:
         The dropout rate.
 
-    attn_dropout:
-        The dropout rate for the attention map.
-
     """
 
     def __init__(
@@ -185,12 +191,13 @@ def __init__(
         n_heads: int,
         d_k: int,
         d_v: int,
+        slf_attn_opt: AttentionOperator,
+        enc_attn_opt: AttentionOperator,
         dropout: float = 0.1,
-        attn_dropout: float = 0.1,
     ):
         super().__init__()
-        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
-        self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt)
+        self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, enc_attn_opt)
         self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)
 
     def forward(
@@ -199,6 +206,7 @@ def forward(
         enc_output: torch.Tensor,
         slf_attn_mask: Optional[torch.Tensor] = None,
         dec_enc_attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward processing of the decoder layer.
 
@@ -231,10 +239,18 @@ def forward(
 
         """
         dec_output, dec_slf_attn = self.slf_attn(
-            dec_input, dec_input, dec_input, attn_mask=slf_attn_mask
+            dec_input,
+            dec_input,
+            dec_input,
+            attn_mask=slf_attn_mask,
+            **kwargs,
         )
         dec_output, dec_enc_attn = self.enc_attn(
-            dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask
+            dec_output,
+            enc_output,
+            enc_output,
+            attn_mask=dec_enc_attn_mask,
+            **kwargs,
         )
         dec_output = self.pos_ffn(dec_output)
         return dec_output, dec_slf_attn, dec_enc_attn
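On the decoder side the same injection happens twice, once for self-attention and once for encoder-decoder attention, and the new **kwargs are threaded through both calls. A hedged usage sketch follows; the leading d_model and d_ffn arguments are assumed (they sit outside the visible hunk), as are the hyperparameter values and tensor shapes.

import torch

from pypots.nn.modules.transformer.attention import ScaledDotProductAttention
from pypots.nn.modules.transformer.layers import DecoderLayer

d_model, d_ffn, n_heads, d_k, d_v = 64, 128, 4, 16, 16

decoder_layer = DecoderLayer(
    d_model,
    d_ffn,
    n_heads,
    d_k,
    d_v,
    ScaledDotProductAttention(d_k**0.5, 0.1),  # operator for self-attention
    ScaledDotProductAttention(d_k**0.5, 0.1),  # operator for encoder-decoder attention
    dropout=0.1,
)

dec_input = torch.randn(8, 24, d_model)   # assumed decoder input
enc_output = torch.randn(8, 24, d_model)  # assumed encoder output
dec_output, dec_slf_attn, dec_enc_attn = decoder_layer(dec_input, enc_output)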