T5rpe #141
Merged: 7 commits, Feb 28, 2021
Changes from 1 commit
Update to work with args 2.
MicPie committed Feb 21, 2021
commit 065f138a2adcbd7e621c770480c52d51b83e2de7
8 changes: 8 additions & 0 deletions megatron/arguments.py
@@ -272,6 +272,14 @@ def _add_training_args(parser):
help='Disables weight tying between embedding weights and final Linear layer')
group.add_argument('--sinusoidal-pos-emb', action='store_true',
help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')
+ group.add_argument('--rpe', action='store_true',
+ help='T5 relative positional encoding')
+ group.add_argument('--rpe-causal', action='store_true',
+ help='T5 relative positional encoding causal flag')
+ group.add_argument('--rpe-num-buckets', type=int, default=32,
+ help='T5 relative positional encoding number of buckets, default 32.')
+ group.add_argument('--rpe-max-distance', type=int, default=128,
+ help='T5 relative positional encoding max distance, default 128.')
group.add_argument('--bias-dropout-fusion', action='store_true',
help='Enable bias and dropout fusion.')
group.add_argument('--sparsity', type=str, default='none',
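For context on the two integer flags: in the T5 scheme, the signed offset between a query and a key position is mapped to one of a fixed number of buckets, with exact buckets for small offsets and log-spaced buckets for larger offsets up to a maximum distance. Below is a minimal sketch of that bucketing, following the T5 paper and common PyTorch ports rather than the exact code this PR adds elsewhere:

import math
import torch

def relative_position_bucket(relative_position, causal=True,
                             num_buckets=32, max_distance=128):
    # Map signed (key_pos - query_pos) offsets to bucket ids in [0, num_buckets).
    ret = 0
    n = -relative_position  # positive for keys in the past
    if not causal:
        # Bidirectional: split the buckets between past and future offsets.
        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)
    else:
        # Causal: future offsets are clamped to zero.
        n = torch.max(n, torch.zeros_like(n))

    # Half of the buckets hold exact small offsets ...
    max_exact = num_buckets // 2
    is_small = n < max_exact

    # ... the other half cover larger offsets on a log scale up to max_distance.
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    val_if_large = torch.min(
        val_if_large, torch.full_like(val_if_large, num_buckets - 1))

    return ret + torch.where(is_small, n, val_if_large)

With the defaults wired above (--rpe-num-buckets 32, --rpe-max-distance 128), offsets around the maximum distance and beyond all collapse into the last bucket.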
8 changes: 4 additions & 4 deletions megatron/model/transformer.py
@@ -591,10 +591,10 @@ def __init__(self, attention_mask_func,
super(ParallelTransformer, self).__init__()
args = get_args()

- self.rpe = args.rpe # True
- self.rpe_causal = args.rpe_causal # False
- self.rpe_num_buckets = args.rpe_num_buckets # 32
- self.rpe_max_distance = args.rpe_max_distance # 128
+ rpe = args.rpe
+ rpe_causal = args.rpe_causal
+ rpe_num_buckets = args.rpe_num_buckets
+ rpe_max_distance = args.rpe_max_distance

# Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations
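The rpe_* values read from args above would then typically be forwarded into a relative-position-bias module whose output is added to the raw attention scores. This commit only reads the arguments and the actual wiring is not shown here, so the module below is an illustration only: the class name, constructor signature, and reuse of the bucketing helper from the earlier sketch are assumptions, not the PR's code.

import torch
import torch.nn as nn

class T5RelativePositionBias(nn.Module):
    # Hypothetical module; reuses relative_position_bucket from the sketch above.
    def __init__(self, num_heads, causal=False, num_buckets=32, max_distance=128):
        super().__init__()
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One learned scalar per (bucket, attention head) pair.
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    def forward(self, seq_len, device=None):
        q_pos = torch.arange(seq_len, device=device)
        k_pos = torch.arange(seq_len, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]             # [seq, seq]
        buckets = relative_position_bucket(
            rel_pos, causal=self.causal,
            num_buckets=self.num_buckets, max_distance=self.max_distance)
        bias = self.relative_attention_bias(buckets)           # [seq, seq, heads]
        return bias.permute(2, 0, 1).unsqueeze(0)              # [1, heads, seq, seq]

In that reading, ParallelTransformer would instantiate one such module when args.rpe is set, for example T5RelativePositionBias(args.num_attention_heads, rpe_causal, rpe_num_buckets, rpe_max_distance), and add its output to the attention scores before the softmax; again, this is an inferred pattern, not a description of the code merged in this PR.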