Merge pull request #7 from EleutherAI/sinusoid_pos_emb
Add Sinusoidal Positional Embedding
StellaAthena committed Feb 8, 2021
2 parents a85b29c + 96404bb commit 261b0ec
Showing 3 changed files with 48 additions and 22 deletions.
4 changes: 3 additions & 1 deletion megatron/arguments.py
@@ -263,7 +263,9 @@ def _add_training_args(parser):
                        help='Enable geglu activation function (WARNING: will increase memory usage, '
                             'adjust embd dims accordingly)')
     group.add_argument('--no-weight-tying', action='store_true',
-                        help='Disables weight tying between embedding weights and final Linear layer')
+                       help='Disables weight tying between embedding weights and final Linear layer')
+    group.add_argument('--sinusoidal-pos-emb', action='store_true',
+                       help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')
     group.add_argument('--bias-dropout-fusion', action='store_true',
                        help='Enable bias and dropout fusion.')
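As a point of reference, the flag can be exercised on its own with a minimal argparse sketch. The parser and argument group below are illustrative stand-ins, not the repository's actual _add_training_args; the sketch only shows that the store_true option defaults to False and that argparse exposes --sinusoidal-pos-emb as args.sinusoidal_pos_emb, the attribute read by the model code in the files below.

import argparse

# Minimal stand-in for the training-argument group; only the flags touched
# by this commit are reproduced.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('train')
group.add_argument('--no-weight-tying', action='store_true',
                   help='Disables weight tying between embedding weights and final Linear layer')
group.add_argument('--sinusoidal-pos-emb', action='store_true',
                   help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')

args = parser.parse_args(['--sinusoidal-pos-emb'])
print(args.sinusoidal_pos_emb)  # True; defaults to False when the flag is omitted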

5 changes: 4 additions & 1 deletion megatron/model/gpt2_model.py
@@ -168,6 +168,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True, topology=None):
                                             args.hidden_dropout,
                                             self.init_method,
                                             self.num_tokentypes,
+                                            args.sinusoidal_pos_emb,
                                             tied_weight_attr='word_embeddings_weight'))
         else:
             self.specs.append(LayerSpec(EmbeddingPipe,
@@ -176,7 +177,8 @@ def __init__(self, num_tokentypes=0, parallel_output=True, topology=None):
                                         args.max_position_embeddings,
                                         args.hidden_dropout,
                                         self.init_method,
-                                        self.num_tokentypes))
+                                        self.num_tokentypes,
+                                        args.sinusoidal_pos_emb))

         # outputs are now (hidden_states, attention_mask)

@@ -220,6 +222,7 @@ def _logits_helper(embedding, lm_output):
                               args.hidden_dropout,
                               self.init_method,
                               self.num_tokentypes,
+                              args.sinusoidal_pos_emb,
                               forward_fn=_logits_helper,
                               tied_weight_attr='word_embeddings_weight')
            )
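Both branches build the embedding layer lazily through a layer spec, so the new args.sinusoidal_pos_emb has to sit in the same position as the matching constructor parameter added in megatron/model/language_model.py below. A toy sketch of that deferred-construction pattern follows; ToyLayerSpec and ToyEmbedding are made up for illustration and are not DeepSpeed's or the repository's classes.

# Toy stand-in for a pipeline layer spec: it records a module class plus the
# positional and keyword arguments, and builds the module later. Just enough
# to show why the new flag must line up positionally with the constructor.
class ToyLayerSpec:
    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs

    def build(self):
        return self.typename(*self.module_args, **self.module_kwargs)


class ToyEmbedding:
    def __init__(self, hidden_size, num_tokentypes=0,
                 sinusoidal_positional_embedding=False):
        self.sinusoidal = sinusoidal_positional_embedding


spec = ToyLayerSpec(ToyEmbedding, 512, 0, True)  # hidden_size, num_tokentypes, sinusoidal flag
print(spec.build().sinusoidal)  # True

Because the spec forwards its positional arguments in order, args.sinusoidal_pos_emb is inserted right after self.num_tokentypes at every call site, mirroring the new parameter's position in Embedding.__init__ while the trailing forward_fn and tied_weight_attr stay as keywords.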
61 changes: 41 additions & 20 deletions megatron/model/language_model.py
@@ -103,6 +103,19 @@ def forward(self, hidden_states, sequence_index=0):
         return pooled


+class SinusoidalPositionalEmbedding(MegatronModule):
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        t = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
+        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+        return emb[None, :, :]
+
+
 class Embedding(MegatronModule):
     """Language model embeddings.
@@ -123,7 +136,8 @@ def __init__(self,
                  max_sequence_length,
                  embedding_dropout_prob,
                  init_method,
-                 num_tokentypes=0):
+                 num_tokentypes=0,
+                 sinusoidal_positional_embedding=False):
         super(Embedding, self).__init__()

         self.hidden_size = hidden_size
@@ -136,11 +150,15 @@ def __init__(self,
         self._word_embeddings_key = 'word_embeddings'

         # Position embedding (serial).
-        self.position_embeddings = torch.nn.Embedding(
-            max_sequence_length, self.hidden_size)
-        self._position_embeddings_key = 'position_embeddings'
-        # Initialize the position embeddings.
-        self.init_method(self.position_embeddings.weight)
+        self.sinusoidal_positional_embedding = sinusoidal_positional_embedding
+        if not self.sinusoidal_positional_embedding:
+            self.position_embeddings = torch.nn.Embedding(
+                max_sequence_length, self.hidden_size)
+            self._position_embeddings_key = 'position_embeddings'
+            # Initialize the position embeddings.
+            self.init_method(self.position_embeddings.weight)
+        else:
+            self.position_embeddings = SinusoidalPositionalEmbedding(self.hidden_size)

         # Token type embedding.
         # Add this as an optional field that can be added through
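None of the hunks in this commit touch Embedding.forward. Assuming the forward pass follows the upstream Megatron pattern of words_embeddings + self.position_embeddings(position_ids), the swap still works without changes: a learned nn.Embedding indexed with position_ids of shape [batch, seq] returns [batch, seq, hidden], while the sinusoidal module ignores the values of its input, uses only its sequence length, and returns [1, seq, hidden], which broadcasts over the batch in the addition. A sketch with dummy tensors, where the shapes and the forward expression are assumptions rather than part of this diff:

import torch

batch, seq, hidden = 4, 10, 64

# Stand-ins for what Embedding.forward would see; the forward pass itself is
# not modified by this commit.
words_embeddings = torch.randn(batch, seq, hidden)
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, seq)

# Learned path: an nn.Embedding lookup gives one vector per (batch, position).
learned = torch.nn.Embedding(2048, hidden)
out_learned = words_embeddings + learned(position_ids)        # [4, 10, 64]

# Sinusoidal path: only position_ids.shape[1] is used, and the [1, seq, hidden]
# result broadcasts across the batch when added.
inv_freq = 1. / (10000 ** (torch.arange(0, hidden, 2).float() / hidden))
t = torch.arange(position_ids.shape[1]).type_as(inv_freq)
sinusoid_inp = torch.einsum("i,j->ij", t, inv_freq)
sin_emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)[None, :, :]
out_sinusoidal = words_embeddings + sin_emb                   # [4, 10, 64]

print(out_learned.shape, out_sinusoidal.shape)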
@@ -197,9 +215,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
         state_dict_ = {}
         state_dict_[self._word_embeddings_key] \
             = self.word_embeddings.state_dict(destination, prefix, keep_vars)
-        state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+        if not self.sinusoidal_positional_embedding:
+            state_dict_[self._position_embeddings_key] \
+                = self.position_embeddings.state_dict(
+                    destination, prefix, keep_vars)
         if self.num_tokentypes > 0:
             state_dict_[self._tokentype_embeddings_key] \
                 = self.tokentype_embeddings.state_dict(
@@ -223,16 +242,17 @@ def load_state_dict(self, state_dict, strict=True):
         self.word_embeddings.load_state_dict(state_dict_, strict=strict)

         # Position embedding.
-        if self._position_embeddings_key in state_dict:
-            state_dict_ = state_dict[self._position_embeddings_key]
-        else:
-            # for backward compatibility.
-            state_dict_ = {}
-            for key in state_dict.keys():
-                if 'position_embeddings' in key:
-                    state_dict_[key.split('position_embeddings.')[1]] \
-                        = state_dict[key]
-        self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+        if not self.sinusoidal_positional_embedding:
+            if self._position_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._position_embeddings_key]
+            else:
+                # for backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'position_embeddings' in key:
+                        state_dict_[key.split('position_embeddings.')[1]] \
+                            = state_dict[key]
+            self.position_embeddings.load_state_dict(state_dict_, strict=strict)

         # Tokentype embedding.
         if self.num_tokentypes > 0:
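Checkpoint handling follows the same pattern in both hunks above: when sinusoidal positions are enabled, no position_embeddings entry is written or read. The only state in SinusoidalPositionalEmbedding is the inv_freq buffer, which is recomputed deterministically from the hidden size in __init__, so nothing is lost by skipping it. A tiny sketch of that argument; make_inv_freq is a hypothetical helper that copies the formula from the class above.

import torch

def make_inv_freq(dim):
    # Same formula the new module uses in __init__; it depends only on dim.
    return 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))

# Two independent "instances" produce identical buffers, so skipping the
# position-embedding entry in the checkpoint loses no information.
assert torch.equal(make_inv_freq(64), make_inv_freq(64))
print("inv_freq is deterministic; nothing to checkpoint for sinusoidal positions")

Conversely, checkpoints trained with learned position embeddings load exactly as before, since the learned branch is unchanged apart from the added guard.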
@@ -313,7 +333,8 @@ def __init__(self,
                                    args.max_position_embeddings,
                                    args.hidden_dropout,
                                    self.init_method,
-                                   self.num_tokentypes)
+                                   self.num_tokentypes,
+                                   args.sinusoidal_pos_emb)
         self._embedding_key = 'embedding'

         # Transformer
