Embedding module
moon-hotel committed Jun 28, 2021
1 parent ba663f2 commit d312fca
Showing 1 changed file with 65 additions and 0 deletions.
Embedding.py
@@ -0,0 +1,65 @@
import torch.nn as nn
import torch
import math


class PositionalEncoding(nn.Module):
r"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
the embeddings, so that the two can be summed. Here, we use sine and cosine
functions of different frequencies.
.. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
        where pos is the word position and i is the embedding index.
Args:
d_model: the embed dim (required).
dropout: the dropout value (default=0.1).
max_len: the max. length of the incoming sequence (default=5000).
Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
"""

def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model) # [max_len, d_model]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [max_len, 1]
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # [d_model/2]
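        # Note: div_term equals 1 / 10000^(2i / d_model) for i = 0, 1, ..., d_model/2 - 1,
        # computed via exp() of a log for numerical stability; multiplying by `position`
        # below gives the pos / 10000^(2i / d_model) arguments from the docstring formulas.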
pe[:, 0::2] = torch.sin(position * div_term) # [max_len, d_model/2]
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1) # [max_len, 1, d_model]
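        # `pe` is registered as a buffer (not a parameter): it is excluded from the
        # optimizer, but still moves with .to(device) and is saved in the state_dict.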
self.register_buffer('pe', pe)

def forward(self, x): # [x_len, batch_size, d_model]
"""
:param x: [x_len, batch_size, emb_size]
:return: [x_len, batch_size, emb_size]
"""
        # self.pe[:x.size(0), :] has shape [x_len, 1, d_model] and broadcasts over batch_size
        x = x + self.pe[:x.size(0), :]  # [x_len, batch_size, d_model]
return self.dropout(x)


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
super(TokenEmbedding, self).__init__()
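        # nn.Embedding is a trainable [vocab_size, emb_size] lookup table mapping token ids to vectors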
self.embedding = nn.Embedding(vocab_size, emb_size)
self.emb_size = emb_size

"""
:param tokens: shape : [len, batch_size]
:return: shape: [len, batch_size, emb_size]
"""

def forward(self, tokens):
return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


if __name__ == '__main__':
    x = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
    x = x.transpose(0, 1)  # [src_len, batch_size] = [5, 2]; keeps each sequence in its own column
token_embedding = TokenEmbedding(vocab_size=11, emb_size=512)
x = token_embedding(tokens=x)
pos_embedding = PositionalEncoding(d_model=512)
x = pos_embedding(x=x)
print(x.shape)
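    # Illustrative sanity check: adding the positional encoding leaves the
    # [src_len, batch_size, d_model] shape unchanged.
    assert x.shape == (5, 2, 512)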
