solve some problem (positional encoding matrix size)
hyunwoongko committed Oct 28, 2019
1 parent d5a4d5a commit a970643
Showing 15 changed files with 83 additions and 91 deletions.
75 changes: 22 additions & 53 deletions .idea/workspace.xml


Binary file modified __pycache__/data.cpython-36.pyc
Binary file added __pycache__/train.cpython-36.pyc
10 changes: 10 additions & 0 deletions main.py
@@ -0,0 +1,10 @@
"""
@author : Hyunwoong
@when : 2019-10-29
@homepage : https://github.com/gusdnd852
"""
import train
from conf import *

if __name__ == '__main__':
train.run(total_epoch=epoch, best_loss=inf)
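main.py pulls `epoch`, `inf`, and (indirectly) the device from `conf` via a star import. The conf module itself is not part of this diff; the sketch below is only an assumption of the kind of names it would have to define for main.py and the new `device` argument to work, with made-up values:

```python
# conf.py -- hypothetical sketch, not taken from this commit
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

d_model = 512        # embedding dimension
max_len = 512        # maximum sequence length for the positional encoding table
drop_prob = 0.1      # dropout probability
epoch = 1000         # total epochs consumed by train.run(total_epoch=epoch, ...)
inf = float('inf')   # initial best_loss sentinel passed to train.run()
```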
Binary file modified models/__pycache__/transformer.cpython-36.pyc
Binary file modified models/embedding/__pycache__/positional_encoding.cpython-36.pyc
Binary file modified models/embedding/__pycache__/transformer_embedding.cpython-36.pyc
22 changes: 14 additions & 8 deletions models/embedding/positional_encoding.py
@@ -3,18 +3,16 @@
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""

import torch
from torch import nn

-from conf import *
-

class PostionalEncoding(nn.Module):
"""
compute sinusoid encoding.
"""

-def __init__(self, d_model, max_len):
+def __init__(self, d_model, max_len, device):
"""
constructor of sinusoid encoding class
@@ -24,13 +22,13 @@ def __init__(self, d_model, max_len):
super(PostionalEncoding, self).__init__()

# same size with input matrix (for adding with input matrix)
-self.encoding = torch.zeros(max_len, d_model, device=conf_device, requires_grad=False)
+self.encoding = torch.zeros(max_len, d_model, device=device, requires_grad=False)

-pos = torch.arange(0, max_len, device=conf_device)
+pos = torch.arange(0, max_len, device=device)
pos = pos.float().unsqueeze(dim=1)
# 1D => 2D unsqueeze to represent word's position

-_2i = torch.arange(0, d_model, step=2, device=conf_device).float()
+_2i = torch.arange(0, d_model, step=2, device=device).float()
# 'i' means index of d_model (e.g. embedding size = 50, 'i' = [0,50])
# "step=2" means 'i' multiplied with two (same with 2 * i)

@@ -41,4 +39,12 @@
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model))) # if 'i' is odd [1, 3, 5, ... ] => cos

def forward(self, x):
-return self.encoding[:, :x.size(1)]
+# self.encoding
+# [max_len = 512, d_model = 512]
+
+batch_size, seq_len = x.size()
+# [batch_size = 128, seq_len = 31]
+
+return self.encoding[:seq_len, :]
+# [seq_len = 31, d_model = 512]
+# it will be added to tok_emb : [128, 31, 512]
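The commit title refers to the encoding matrix size: the table is allocated once as [max_len, d_model], and the new forward slices only the first seq_len rows so it broadcasts onto the token embeddings. A self-contained sketch of that behaviour (the shapes 512/128/31 are the illustrative values from the diff comments, and the frequency term follows the standard sinusoid formula):

```python
import torch

def build_sinusoid_table(max_len, d_model, device=None):
    table = torch.zeros(max_len, d_model, device=device)                 # [max_len, d_model]
    pos = torch.arange(0, max_len, device=device).float().unsqueeze(1)   # [max_len, 1]
    _2i = torch.arange(0, d_model, step=2, device=device).float()        # even indices 0, 2, 4, ...
    table[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))         # even dims -> sin
    table[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))         # odd dims  -> cos
    return table

encoding = build_sinusoid_table(max_len=512, d_model=512)   # [512, 512]
x = torch.randint(0, 100, (128, 31))                        # token ids: [batch_size, seq_len]
pos_emb = encoding[:x.size(1), :]                           # [31, 512] -- only the first seq_len rows
tok_emb = torch.randn(128, 31, 512)                         # stand-in for TokenEmbedding output
print((tok_emb + pos_emb).shape)                            # broadcasts to torch.Size([128, 31, 512])
```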
12 changes: 7 additions & 5 deletions models/embedding/transformer_embedding.py
@@ -15,7 +15,7 @@ class TransformerEmbedding(nn.Module):
positional encoding can give positional information to network
"""

-def __init__(self, vocab_size, d_model, drop_prob):
+def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
"""
class for word embedding that included positional information
@@ -24,10 +22,12 @@ class for word embedding that included positional information
"""
super(TransformerEmbedding, self).__init__()
self.tok_emb = TokenEmbedding(vocab_size, d_model)
-self.pos_emb = PostionalEncoding(d_model)
+self.pos_emb = PostionalEncoding(d_model, max_len, device)
self.drop_out = nn.Dropout(p=drop_prob)

def forward(self, x):
-embedding = self.tok_emb(x) + self.pos_emb(x)
-embedding = self.drop_out(embedding)
-return embedding
+tok_emb = self.tok_emb(x)
+pos_emb = self.pos_emb(x)
+print(tok_emb.size(), pos_emb.size())
+
+return self.drop_out(tok_emb + pos_emb)
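TransformerEmbedding now threads max_len and device down to PostionalEncoding and sums the two embeddings under dropout. A minimal sketch of that composition, approximating TokenEmbedding with a plain nn.Embedding (vocabulary size and shapes are assumptions):

```python
import torch
from torch import nn

vocab_size, d_model, max_len, drop_prob = 10_000, 512, 512, 0.1   # illustrative values

tok_emb = nn.Embedding(vocab_size, d_model)    # rough stand-in for TokenEmbedding
pos_table = torch.zeros(max_len, d_model)      # stand-in for PostionalEncoding.encoding
drop_out = nn.Dropout(p=drop_prob)

x = torch.randint(0, vocab_size, (128, 31))    # [batch_size, seq_len]
t = tok_emb(x)                                 # [128, 31, 512]
p = pos_table[:x.size(1), :]                   # [31, 512], broadcast over the batch
out = drop_out(t + p)                          # [128, 31, 512]
print(out.shape)
```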
16 changes: 10 additions & 6 deletions models/transformer.py
@@ -9,16 +9,20 @@


class Transformer(nn.Module):
-def __init__(self, enc_voc_size, dec_voc_size, d_model, drop_prob):
+def __init__(self, enc_voc_size, dec_voc_size, d_model, max_len, drop_prob, device):
super(Transformer, self).__init__()
self.enc_embedding = TransformerEmbedding(vocab_size=enc_voc_size,
d_model=d_model,
-drop_prob=drop_prob)
+max_len=max_len,
+drop_prob=drop_prob,
+device=device)

self.dec_embedding = TransformerEmbedding(vocab_size=dec_voc_size,
d_model=d_model,
-drop_prob=drop_prob)
+max_len=max_len,
+drop_prob=drop_prob,
+device=device)

-def forward(self, x):
-x = self.enc_embedding(x)
-return x
+def forward(self, source, target):
+source = self.enc_embedding(source)
+return source
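At this commit the Transformer only embeds the source (the decoder path is still a stub), but the constructor already takes max_len and device. A usage sketch under that assumption; the import path, vocabulary sizes, and shapes are illustrative:

```python
import torch
from models.transformer import Transformer   # assumes running from the repository root

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(enc_voc_size=10_000,
                    dec_voc_size=10_000,
                    d_model=512,
                    max_len=512,
                    drop_prob=0.1,
                    device=device).to(device)

source = torch.randint(0, 10_000, (128, 31), device=device)   # [batch_size, src_len]
target = torch.randint(0, 10_000, (128, 27), device=device)   # [batch_size, trg_len], unused so far
out = model(source, target)
print(out.shape)   # expected: torch.Size([128, 31, 512]) -- the embedded source
```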
(remaining changed files not shown)
