From 96404bbc21ea1f54b0107b437beb920437595b3c Mon Sep 17 00:00:00 2001
From: sdtblck
Date: Mon, 8 Feb 2021 14:28:59 +0000
Subject: [PATCH] add sinusoidal positional embedding

---
 megatron/arguments.py            |  4 ++-
 megatron/model/gpt2_model.py     |  5 ++-
 megatron/model/language_model.py | 61 +++++++++++++++++++++-----------
 3 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index ffb5511be..8951eb0e1 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -263,7 +263,9 @@ def _add_training_args(parser):
                        help='Enable geglu activation function (WARNING: will increase memory usage, '
                             'adjust embd dims accordingly)')
     group.add_argument('--no-weight-tying', action='store_true',
-                        help='Disables weight tying between embedding weights and final Linear layer')
+                       help='Disables weight tying between embedding weights and final Linear layer')
+    group.add_argument('--sinusoidal-pos-emb', action='store_true',
+                       help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')
     group.add_argument('--bias-dropout-fusion', action='store_true',
                        help='Enable bias and dropout fusion.')
 
diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 563e4abed..1f717f05c 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -168,6 +168,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True, topology=None):
                                             args.hidden_dropout,
                                             self.init_method,
                                             self.num_tokentypes,
+                                            args.sinusoidal_pos_emb,
                                             tied_weight_attr='word_embeddings_weight'))
         else:
             self.specs.append(LayerSpec(EmbeddingPipe,
@@ -176,7 +177,8 @@ def __init__(self, num_tokentypes=0, parallel_output=True, topology=None):
                                         args.max_position_embeddings,
                                         args.hidden_dropout,
                                         self.init_method,
-                                        self.num_tokentypes))
+                                        self.num_tokentypes,
+                                        args.sinusoidal_pos_emb))
 
         # outputs are now (hidden_states, attention_mask)
 
@@ -220,6 +222,7 @@ def _logits_helper(embedding, lm_output):
                       args.hidden_dropout,
                       self.init_method,
                       self.num_tokentypes,
+                      args.sinusoidal_pos_emb,
                       forward_fn=_logits_helper,
                       tied_weight_attr='word_embeddings_weight')
             )
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index 0e591262f..b0a83fb8c 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -103,6 +103,19 @@ def forward(self, hidden_states, sequence_index=0):
         return pooled
 
 
+class SinusoidalPositionalEmbedding(MegatronModule):
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        t = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
+        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+        return emb[None, :, :]
+
+
 class Embedding(MegatronModule):
     """Language model embeddings.
 
@@ -123,7 +136,8 @@ def __init__(self,
                  max_sequence_length,
                  embedding_dropout_prob,
                  init_method,
-                 num_tokentypes=0):
+                 num_tokentypes=0,
+                 sinusoidal_positional_embedding=False):
         super(Embedding, self).__init__()
 
         self.hidden_size = hidden_size
@@ -136,11 +150,15 @@ def __init__(self,
         self._word_embeddings_key = 'word_embeddings'
 
         # Position embedding (serial).
-        self.position_embeddings = torch.nn.Embedding(
-            max_sequence_length, self.hidden_size)
-        self._position_embeddings_key = 'position_embeddings'
-        # Initialize the position embeddings.
-        self.init_method(self.position_embeddings.weight)
+        self.sinusoidal_positional_embedding = sinusoidal_positional_embedding
+        if not self.sinusoidal_positional_embedding:
+            self.position_embeddings = torch.nn.Embedding(
+                max_sequence_length, self.hidden_size)
+            self._position_embeddings_key = 'position_embeddings'
+            # Initialize the position embeddings.
+            self.init_method(self.position_embeddings.weight)
+        else:
+            self.position_embeddings = SinusoidalPositionalEmbedding(self.hidden_size)
 
         # Token type embedding.
         # Add this as an optional field that can be added through
@@ -197,9 +215,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
         state_dict_ = {}
         state_dict_[self._word_embeddings_key] \
             = self.word_embeddings.state_dict(destination, prefix, keep_vars)
-        state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+        if not self.sinusoidal_positional_embedding:
+            state_dict_[self._position_embeddings_key] \
+                = self.position_embeddings.state_dict(
+                    destination, prefix, keep_vars)
         if self.num_tokentypes > 0:
             state_dict_[self._tokentype_embeddings_key] \
                 = self.tokentype_embeddings.state_dict(
@@ -223,16 +242,17 @@ def load_state_dict(self, state_dict, strict=True):
         self.word_embeddings.load_state_dict(state_dict_, strict=strict)
 
         # Position embedding.
-        if self._position_embeddings_key in state_dict:
-            state_dict_ = state_dict[self._position_embeddings_key]
-        else:
-            # for backward compatibility.
-            state_dict_ = {}
-            for key in state_dict.keys():
-                if 'position_embeddings' in key:
-                    state_dict_[key.split('position_embeddings.')[1]] \
-                        = state_dict[key]
-        self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+        if not self.sinusoidal_positional_embedding:
+            if self._position_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._position_embeddings_key]
+            else:
+                # for backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'position_embeddings' in key:
+                        state_dict_[key.split('position_embeddings.')[1]] \
+                            = state_dict[key]
+            self.position_embeddings.load_state_dict(state_dict_, strict=strict)
 
         # Tokentype embedding.
         if self.num_tokentypes > 0:
@@ -313,7 +333,8 @@ def __init__(self,
                                        args.max_position_embeddings,
                                        args.hidden_dropout,
                                        self.init_method,
-                                       self.num_tokentypes)
+                                       self.num_tokentypes,
+                                       args.sinusoidal_pos_emb)
         self._embedding_key = 'embedding'
 
         # Transformer
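Below is a minimal standalone sketch of what the SinusoidalPositionalEmbedding added above computes. It is not part of the patch: MegatronModule is swapped for torch.nn.Module so the snippet runs outside Megatron, and the toy shapes plus the final addition to a stand-in word-embedding tensor are illustrative assumptions, since the hunks above do not show how Embedding.forward consumes the returned tensor when --sinusoidal-pos-emb is set.

import torch


class SinusoidalPositionalEmbedding(torch.nn.Module):
    # Same math as the module added in the patch; only the base class differs
    # so this sketch runs without Megatron.
    def __init__(self, dim):
        super().__init__()
        # Inverse frequencies 1 / 10000^(2i/dim), one per sin/cos pair.
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        # One position index per timestep; x is [batch, seq, hidden].
        t = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        # Outer product [seq] x [dim/2] -> [seq, dim/2] of position * frequency.
        sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
        # Concatenate sines and cosines -> [seq, dim], then add a batch axis.
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :]


if __name__ == "__main__":
    hidden_size, seq_len, batch = 16, 8, 2
    pos_emb = SinusoidalPositionalEmbedding(hidden_size)

    # Stand-in for the word embeddings produced inside the Embedding module.
    word_embeddings = torch.randn(batch, seq_len, hidden_size)

    positions = pos_emb(word_embeddings)
    print(positions.shape)   # torch.Size([1, 8, 16])

    # Typical use: add the fixed positional table to the token embeddings;
    # the leading 1 broadcasts across the batch dimension.
    combined = word_embeddings + positions
    print(combined.shape)    # torch.Size([2, 8, 16])

Because the sinusoidal table lives in a registered buffer rather than a trained torch.nn.Embedding, there are no position-embedding weights to persist, which is why the patch skips the position-embedding entries in state_dict_for_save_checkpoint and load_state_dict whenever the flag is enabled.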