Commit

support position interpolation for longer attention context window length.
zh794390558 committed Jul 13, 2023
1 parent 3b6b680 commit b91b1c9
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions paddlespeech/s2t/modules/embedding.py
@@ -167,3 +167,40 @@ def forward(self, x: paddle.Tensor,
x = x * self.xscale
pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)


# RotaryRelPositionalEncoding is the same as RelPositionalEncoding
class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
"""Scaled Rotary Relative positional encoding module.
    Position Interpolation: https://arxiv.org/pdf/2306.15595v2.pdf
"""

def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int=5000,
                 scale: int=1):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
            scale (int): Interpolation factor; the table covers
                `scale * max_len` input positions.
"""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)
        self.scale = scale
        self.max_len = max_len * scale
        # Assumption: the parent allocates `self.pe` for only `max_len` rows,
        # so re-allocate it to hold the `scale * max_len` interpolated rows.
        self.pe = paddle.zeros([1, self.max_len, self.d_model])

        position = paddle.arange(
            0, self.max_len, dtype=paddle.float32).unsqueeze(1)  # [T, 1]
        # Position Interpolation: compress indices by 1/scale, so table row k
        # holds the sinusoid for position k / scale within the trained range.
        position *= 1.0 / self.scale

        # base^{-2(i-1)/d}, i in (1, 2, ..., d/2)
        div_term = paddle.exp(
            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
            (math.log(self.base) / self.d_model))

        # interleave sin/cos over the feature dim; self.pe: [1, T, D]
        self.pe[:, :, 0::2] = paddle.sin(position * div_term)
        self.pe[:, :, 1::2] = paddle.cos(position * div_term)
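
The scaling step is the whole trick: table row k stores the sinusoid for position k / scale, trading resolution for range instead of extrapolating past the trained window. A minimal self-contained sketch (hypothetical dimensions, plain sinusoids with base 10000, not code from this commit) that checks the property:

import math
import paddle

d_model, max_len, scale = 8, 16, 2

# Interpolated positions: row k corresponds to position k / scale.
pos = paddle.arange(
    0, max_len * scale, dtype=paddle.float32).unsqueeze(1) / scale
div_term = paddle.exp(
    -paddle.arange(0, d_model, 2, dtype=paddle.float32) *
    (math.log(10000.0) / d_model))

pe = paddle.zeros([1, max_len * scale, d_model])
pe[:, :, 0::2] = paddle.sin(pos * div_term)
pe[:, :, 1::2] = paddle.cos(pos * div_term)

# Every scale-th row of the interpolated table equals the unscaled table.
base_pos = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
assert paddle.allclose(pe[0, ::scale, 0::2], paddle.sin(base_pos * div_term))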
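
Usage would follow the other encodings in this file; a sketch assuming the inherited forward(x, offset) shown in the diff context above, with illustrative dimensions:

import paddle
from paddlespeech.s2t.modules.embedding import ScaledRotaryRelPositionalEncoding

# Cover 2 * 5000 = 10000 positions by interpolation.
pos_enc = ScaledRotaryRelPositionalEncoding(
    d_model=256, dropout_rate=0.1, max_len=5000, scale=2)

x = paddle.randn([1, 8000, 256])   # longer than the original max_len of 5000
x, pos_emb = pos_enc(x, offset=0)  # pos_emb: [1, 8000, 256]

With scale=1 the table should be identical to the parent RelPositionalEncoding, so default behavior is unchanged.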
