Updates for OPT
zphang committed Jul 23, 2022
1 parent 0709327 commit 7c25eed
Showing 3 changed files with 25 additions and 0 deletions.
7 changes: 7 additions & 0 deletions configs/neox_arguments.md
@@ -308,6 +308,13 @@ Model Arguments
    T5 relative positional encoding max distance, default 128.


- **opt_pos_emb_offset**: int

    Default = 0

    Learned position embedding offset (only used by OPT, where it should be set to 2).



- **no_weight_tying**: bool

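In practice, the new argument pairs with learned absolute position embeddings. Below is a minimal sketch of the relevant settings, written as a Python dict for brevity: the `pos_emb` key comes from the same `NeoXArgsModel` class but is not part of this commit, and the surrounding config format is an assumption rather than something taken from the diff.

```python
# Illustrative only: the keys mirror NeoXArgsModel fields; a real GPT-NeoX run
# would set them in its YAML/JSON config file rather than in a Python dict.
opt_style_settings = {
    "pos_emb": "learned",       # assumed key: OPT uses learned absolute position embeddings
    "opt_pos_emb_offset": 2,    # shift position ids by 2, as OPT expects
}
```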
13 changes: 13 additions & 0 deletions megatron/model/word_embeddings.py
@@ -91,6 +91,10 @@ def __init__(

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        self.opt_pos_emb_offset = neox_args.opt_pos_emb_offset

        # For ticking position ids forward
        self.layer_past = None

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
@@ -114,6 +118,15 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        if self.use_pos_emb and self.embedding_type in ["learned", "sinusoidal"]:
            if self.layer_past is not None:
                position_ids = position_ids + self.layer_past + 1

            self.layer_past = position_ids[:, -1]

            # OPT always offsets position ids by 2, matching the HF implementation
            if self.opt_pos_emb_offset:
                position_ids = position_ids + self.opt_pos_emb_offset
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = words_embeddings + position_embeddings
        else:
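To make the new logic concrete, here is a self-contained sketch of an OPT-style learned position embedding in the spirit of the HF implementation the comment refers to. The class and variable names below are illustrative, not part of this commit; in the NeoX change itself, the offset and the `layer_past` ticking are applied to `position_ids` just before the existing `self.position_embeddings` lookup rather than inside a separate module.

```python
import torch
import torch.nn as nn


class OPTStyleLearnedPositions(nn.Module):
    """Illustrative sketch (not the NeoX module): learned position embeddings
    whose lookups are shifted by a fixed offset, with position ids ticked
    forward across incremental decoding steps."""

    def __init__(self, max_positions: int, hidden_size: int, offset: int = 2):
        super().__init__()
        self.offset = offset
        # The table needs `offset` extra rows because every lookup is shifted.
        self.emb = nn.Embedding(max_positions + offset, hidden_size)
        self.layer_past = None  # last position seen, for incremental decoding

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # Tick positions forward when generating token by token.
        if self.layer_past is not None:
            position_ids = position_ids + self.layer_past + 1
        self.layer_past = position_ids[:, -1]
        # OPT-style offset: positions 0..n index rows offset..offset+n.
        return self.emb(position_ids + self.offset)


pos = OPTStyleLearnedPositions(max_positions=2048, hidden_size=8, offset=2)
prompt = pos(torch.arange(4).unsqueeze(0))       # positions 0..3 -> rows 2..5
step = pos(torch.zeros(1, 1, dtype=torch.long))  # ticked to position 4 -> row 6
print(prompt.shape, step.shape)  # torch.Size([1, 4, 8]) torch.Size([1, 1, 8])
```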
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -141,6 +141,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
    T5 relative positional encoding max distance, default 128.
    """

    opt_pos_emb_offset: int = 0
    """
    Learned position embedding offset (only used by OPT, where it should be set to 2).
    """

    no_weight_tying: bool = False
    """
    Disables weight tying between embedding weights and final Linear layer
