Merge pull request #141 from EleutherAI/t5rpe
T5rpe
joshlk committed Feb 28, 2021
2 parents d920e25 + 3685021 commit 87944a8
Showing 4 changed files with 87 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ We run our experiments on a Kubernetes cluster generously provided by [CoreWeave

## Licensing

This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright 2021 Stella Biderman, Sid Black, Leo Gao, Josh Levy-Kramer, and Shivanshu Purohit.
This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright (c) 2021 Stella Biderman, Sid Black, Josh Levy-Kramer, Michael Pieler, and Shivanshu Purohit.

GPT-NeoX is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
10 changes: 9 additions & 1 deletion megatron/arguments.py
@@ -275,7 +275,15 @@ def _add_training_args(parser):
group.add_argument('--no-weight-tying', action='store_true',
help='Disables weight tying between embedding weights and final Linear layer')
group.add_argument('--sinusoidal-pos-emb', action='store_true',
help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')
group.add_argument('--rpe', action='store_true',
help='T5 relative positional encoding')
group.add_argument('--rpe-causal', action='store_true',
help='T5 relative positional encoding causal flag')
group.add_argument('--rpe-num-buckets', type=int, default=32,
help='T5 relative positional encoding number of buckets, default 32.')
group.add_argument('--rpe-max-distance', type=int, default=128,
help='T5 relative positional encoding max distance, default 128.')
group.add_argument('--bias-dropout-fusion', action='store_true',
help='Enable bias and dropout fusion.')
group.add_argument('--sparsity', type=str, default='none',
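For context, and not part of this commit: the four new flags above follow the same argparse pattern as the rest of _add_training_args, so their behaviour can be sketched with a standalone parser. The parser and group below are illustrative stand-ins, not the actual megatron.arguments objects.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('training')
group.add_argument('--rpe', action='store_true',
                   help='T5 relative positional encoding')
group.add_argument('--rpe-causal', action='store_true',
                   help='T5 relative positional encoding causal flag')
group.add_argument('--rpe-num-buckets', type=int, default=32,
                   help='T5 relative positional encoding number of buckets, default 32.')
group.add_argument('--rpe-max-distance', type=int, default=128,
                   help='T5 relative positional encoding max distance, default 128.')

# e.g. enable the bias with causal bucketing and a non-default bucket count
args = parser.parse_args(['--rpe', '--rpe-causal', '--rpe-num-buckets', '64'])
print(args.rpe, args.rpe_causal, args.rpe_num_buckets, args.rpe_max_distance)
# True True 64 128

Omitting --rpe leaves args.rpe False, in which case the transformer changes below leave the attention path untouched.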
47 changes: 47 additions & 0 deletions megatron/model/t5rpe.py
@@ -0,0 +1,47 @@
# Based on: https://github.com/lucidrains/x-transformers/blob/6b93c21be0d0a679da6f7b9621d9bb638ab18428/x_transformers/x_transformers.py#L106 (14.12.2021)

import math
import torch
from torch import nn
from einops import rearrange


class RelativePositionBias(nn.Module):
def __init__(self, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
super().__init__()
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)

@staticmethod
def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
ret = 0
n = -relative_position
if not causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))

max_exact = num_buckets // 2
is_small = n < max_exact

val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

ret += torch.where(is_small, n, val_if_large)
return ret

def forward(self, q_len, k_len):
q_pos = torch.arange(q_len, dtype = torch.long, device = torch.cuda.current_device())
k_pos = torch.arange(k_len, dtype = torch.long, device = torch.cuda.current_device())
rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> () h i j')
return bias
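
For reference, and not part of the diff: RelativePositionBias buckets the signed relative position k_pos - q_pos into a learned bias per attention head, with small distances getting their own exact buckets and larger ones sharing logarithmically spaced buckets out to max_distance. A minimal usage sketch with illustrative sizes, assuming a CUDA device is available (forward() builds its position indices on torch.cuda.current_device()):

import torch
from megatron.model.t5rpe import RelativePositionBias

rpe = RelativePositionBias(causal=True, num_buckets=32, max_distance=128, heads=16).cuda()
bias = rpe(128, 128)      # q_len, k_len
print(bias.shape)         # torch.Size([1, 16, 128, 128]) -> broadcastable over the batch

# The bucketing itself is device-agnostic and can be inspected on CPU:
rel_pos = torch.arange(8)[None, :] - torch.arange(8)[:, None]   # k_pos - q_pos
buckets = RelativePositionBias._relative_position_bucket(rel_pos, causal=True)
print(buckets)            # 0 for future positions, bucket index grows with distance into the past

In this commit the module is constructed once per ParallelTransformer, with heads set to the per-partition attention head count (see the transformer.py changes below).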

34 changes: 30 additions & 4 deletions megatron/model/transformer.py
@@ -26,6 +26,7 @@
from megatron import mpu
from megatron.mpu import LayerNorm, RMSNorm
from megatron.module import MegatronModule
from megatron.model.t5rpe import RelativePositionBias
from megatron.checkpointing import get_checkpoint_version
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
@@ -164,7 +165,8 @@ class ParallelSelfAttention(MegatronModule):
"""

def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=False):
output_layer_init_method, layer_number, sparse=False,
rpe=False):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
@@ -197,6 +199,8 @@ def __init__(self, attention_mask_func, init_method,
coeff = self.layer_number
self.norm_factor *= coeff

self.rpe = rpe

self.sparse = sparse
if self.sparse:
assert args.model_parallel_size <= 1, "TODO: sparsity doesn't yet work with mp size > 1"
@@ -311,6 +315,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
if get_key_value:
present = (key_layer, value_layer)

if self.rpe:
rpe = self.rpe(query_layer.size(0), key_layer.size(0))

if not self.sparse:
# ===================================
# Raw attention scores. [b, np, s, s]
@@ -366,6 +373,9 @@
# Attention probs and dropout
# ===========================

if self.rpe:
attention_scores += rpe # [1, np, sq, sk]

# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
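
A shape-level note on the two hunks above, with illustrative tensors rather than the actual Megatron ones: the bias is [1, np, sq, sk], so adding it to the raw attention scores broadcasts over the batch dimension before scale_mask_softmax runs.

import torch

b, np_, sq, sk = 4, 16, 128, 128                 # batch, heads per partition, query len, key len
attention_scores = torch.randn(b, np_, sq, sk)   # stand-in for the raw scores [b, np, sq, sk]
rpe_bias = torch.randn(1, np_, sq, sk)           # stand-in for the rpe bias [1, np, sq, sk]
attention_scores = attention_scores + rpe_bias   # broadcasts across the batch dimension
print(attention_scores.shape)                    # torch.Size([4, 16, 128, 128])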
@@ -464,7 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
"""

def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=False):
output_layer_init_method, layer_number, sparse=False, rpe=False):
args = get_args()

super(ParallelTransformerLayer, self).__init__()
@@ -489,7 +499,8 @@ def __init__(self, attention_mask_func, init_method,
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
output_layer_init_method,
layer_number,
sparse=sparse)
sparse=sparse,
rpe=rpe)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion

@@ -586,6 +597,11 @@ def __init__(self, attention_mask_func,
super(ParallelTransformer, self).__init__()
args = get_args()

self.rpe = args.rpe
rpe_causal = args.rpe_causal
rpe_num_buckets = args.rpe_num_buckets
rpe_max_distance = args.rpe_max_distance

# Store activation checkpoiting flag.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
@@ -599,6 +615,16 @@ def __init__(self, attention_mask_func,
'number of layers should be divisible by number of unique layers'
self.param_sharing_style = args.param_sharing_style

# Duplicate from lines 181 because we need it for rpe setup:
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)

if self.rpe:
self.rpe = RelativePositionBias(causal=rpe_causal, num_buckets=rpe_num_buckets, max_distance=rpe_max_distance, heads=self.num_attention_heads_per_partition)

# Transformer layers.
sparsity = args.sparsity
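
A rough sizing note for the setup above, with hypothetical numbers (mpu.divide is an exact integer division with a divisibility check): each model-parallel rank builds its own RelativePositionBias over only its local heads, so the learned table is [num_buckets, num_attention_heads_per_partition].

num_attention_heads = 16          # hypothetical global head count
model_parallel_world_size = 2     # hypothetical model-parallel degree
heads_per_partition = num_attention_heads // model_parallel_world_size   # 8

# RelativePositionBias(heads=8) then holds an nn.Embedding(32, 8), i.e. a
# [num_buckets, heads_per_partition] bias table for this rank's heads.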

@@ -613,7 +639,7 @@ def build_layer(layer_number):
raise ValueError(f'Sparsity type {sparsity} not recognized')
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=sparse)
output_layer_init_method, layer_number, sparse=sparse, rpe=self.rpe)

self.layers = torch.nn.ModuleList(
[build_layer(i + 1) for i in range(self.num_unique_layers)])
