T5rpe #141

Merged: 7 commits, Feb 28, 2021
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ We run our experiments on a Kubernetes cluster generously provided by [CoreWeave

## Licensing

This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright 2021 Stella Biderman, Sid Black, Leo Gao, Josh Levy-Kramer, and Shivanshu Purohit.
This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright (c) 2021 Stella Biderman, Sid Black, Josh Levy-Kramer, Michael Pieler, and Shivanshu Purohit.

GPT-NeoX is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
10 changes: 9 additions & 1 deletion megatron/arguments.py
@@ -275,7 +275,15 @@ def _add_training_args(parser):
    group.add_argument('--no-weight-tying', action='store_true',
                       help='Disables weight tying between embedding weights and final Linear layer')
    group.add_argument('--sinusoidal-pos-emb', action='store_true',
                       help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')
                       help='Uses Sinusoidal Positional embedding applied to the inputs instead of learned')
    group.add_argument('--rpe', action='store_true',
                       help='T5 relative positional encoding')
    group.add_argument('--rpe-causal', action='store_true',
                       help='T5 relative positional encoding causal flag')
    group.add_argument('--rpe-num-buckets', type=int, default=32,
                       help='T5 relative positional encoding number of buckets, default 32.')
    group.add_argument('--rpe-max-distance', type=int, default=128,
                       help='T5 relative positional encoding max distance, default 128.')
    group.add_argument('--bias-dropout-fusion', action='store_true',
                       help='Enable bias and dropout fusion.')
    group.add_argument('--sparsity', type=str, default='none',
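The four flags above only toggle the encoding and set its bucket/distance hyperparameters; the head count comes from the model configuration. A minimal sketch of how they map onto the new module, with a hypothetical standalone parser standing in for `_add_training_args` and a placeholder `heads=8`:

```python
import argparse
from megatron.model.t5rpe import RelativePositionBias

# Illustrative stand-in for the Megatron argument group defined above.
parser = argparse.ArgumentParser()
parser.add_argument('--rpe', action='store_true')
parser.add_argument('--rpe-causal', action='store_true')
parser.add_argument('--rpe-num-buckets', type=int, default=32)
parser.add_argument('--rpe-max-distance', type=int, default=128)
args = parser.parse_args(['--rpe', '--rpe-causal'])

if args.rpe:
    # In the real model, heads is the per-partition attention head count.
    rpe = RelativePositionBias(causal=args.rpe_causal,
                               num_buckets=args.rpe_num_buckets,
                               max_distance=args.rpe_max_distance,
                               heads=8)
```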
47 changes: 47 additions & 0 deletions megatron/model/t5rpe.py
@@ -0,0 +1,47 @@
# Based on: https://github.com/lucidrains/x-transformers/blob/6b93c21be0d0a679da6f7b9621d9bb638ab18428/x_transformers/x_transformers.py#L106 (14.12.2021)

import math
import torch
from torch import nn
from einops import rearrange


class RelativePositionBias(nn.Module):
    def __init__(self, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One learned bias value per (bucket, head) pair.
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        if not causal:
            # Bidirectional case: spend half the buckets on each sign of the offset.
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            # Causal case: future positions all collapse into bucket 0.
            n = torch.max(n, torch.zeros_like(n))

        # Half of the remaining buckets map small offsets exactly; the rest
        # cover larger offsets logarithmically up to max_distance.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, q_len, k_len):
        q_pos = torch.arange(q_len, dtype = torch.long, device = torch.cuda.current_device())
        k_pos = torch.arange(k_len, dtype = torch.long, device = torch.cuda.current_device())
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        # [q_len, k_len, heads] -> [1, heads, q_len, k_len] so it broadcasts over the batch.
        bias = rearrange(values, 'i j h -> () h i j')
        return bias
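A minimal usage sketch, assuming the import path of the new file and a CUDA device for `forward()` (which places its position indices on `torch.cuda.current_device()`):

```python
import torch
from megatron.model.t5rpe import RelativePositionBias

# Causal bucketing: past offsets get distinct small buckets, future offsets collapse to 0.
rel = torch.tensor([[-3, -2, -1, 0, 1, 2]])   # k_pos - q_pos
print(RelativePositionBias._relative_position_bucket(
    rel, causal=True, num_buckets=32, max_distance=128))
# tensor([[3, 2, 1, 0, 0, 0]])

if torch.cuda.is_available():                 # forward() uses torch.cuda.current_device()
    rpe = RelativePositionBias(causal=True, num_buckets=32, max_distance=128, heads=16).cuda()
    bias = rpe(8, 8)                          # (q_len, k_len)
    print(bias.shape)                         # torch.Size([1, 16, 8, 8])
```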

34 changes: 30 additions & 4 deletions megatron/model/transformer.py
@@ -26,6 +26,7 @@
from megatron import mpu
from megatron.mpu import LayerNorm, RMSNorm
from megatron.module import MegatronModule
from megatron.model.t5rpe import RelativePositionBias
from megatron.checkpointing import get_checkpoint_version
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
@@ -164,7 +165,8 @@ class ParallelSelfAttention(MegatronModule):
"""

def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=False):
output_layer_init_method, layer_number, sparse=False,
rpe=False):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
@@ -197,6 +199,8 @@ def __init__(self, attention_mask_func, init_method,
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.rpe = rpe

        self.sparse = sparse
        if self.sparse:
            assert args.model_parallel_size <= 1, "TODO: sparsity doesn't yet work with mp size > 1"
@@ -311,6 +315,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
        if get_key_value:
            present = (key_layer, value_layer)

        if self.rpe:
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))

        if not self.sparse:
            # ===================================
            # Raw attention scores. [b, np, s, s]
@@ -366,6 +373,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
            # Attention probs and dropout
            # ===========================

            if self.rpe:
                attention_scores += rpe  # [1, np, sq, sk]

            # attention scores and attention mask [b, np, sq, sk]
            attention_probs = self.scale_mask_softmax(attention_scores,
                                                      attention_mask)
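Shape-wise, `attention_scores` is `[b, np, sq, sk]` and the returned bias is `[1, np, sq, sk]`, so the `+=` above broadcasts one bias across the whole batch. A small standalone sketch of that addition (not the Megatron attention path itself; the CUDA guard is only because `forward()` allocates on the current GPU):

```python
import torch
from megatron.model.t5rpe import RelativePositionBias

b, np_, sq, sk = 2, 4, 8, 8                      # batch, heads per partition, query len, key len
attention_scores = torch.randn(b, np_, sq, sk)

if torch.cuda.is_available():
    rpe = RelativePositionBias(causal=True, heads=np_).cuda()
    attention_scores = attention_scores.cuda()
    attention_scores += rpe(sq, sk)              # [1, np, sq, sk] broadcasts over the batch dim
    print(attention_scores.shape)                # torch.Size([2, 4, 8, 8])
```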
@@ -464,7 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
"""

def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number, sparse=False):
output_layer_init_method, layer_number, sparse=False, rpe=False):
args = get_args()

super(ParallelTransformerLayer, self).__init__()
@@ -489,7 +499,8 @@ def __init__(self, attention_mask_func, init_method,
        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
                                               output_layer_init_method,
                                               layer_number,
                                               sparse=sparse)
                                               sparse=sparse,
                                               rpe=rpe)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

@@ -586,6 +597,11 @@ def __init__(self, attention_mask_func,
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.rpe = args.rpe
        rpe_causal = args.rpe_causal
        rpe_num_buckets = args.rpe_num_buckets
        rpe_max_distance = args.rpe_max_distance

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers
@@ -599,6 +615,16 @@ def __init__(self, attention_mask_func,
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = args.param_sharing_style

        # Duplicate from lines 181 because we need it for rpe setup:
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        if self.rpe:
            self.rpe = RelativePositionBias(causal=rpe_causal, num_buckets=rpe_num_buckets, max_distance=rpe_max_distance, heads=self.num_attention_heads_per_partition)

        # Transformer layers.
        sparsity = args.sparsity

@@ -613,7 +639,7 @@ def build_layer(layer_number):
                raise ValueError(f'Sparsity type {sparsity} not recognized')
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number, sparse=sparse)
                output_layer_init_method, layer_number, sparse=sparse, rpe=self.rpe)

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])
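Note that `ParallelTransformer` builds a single `RelativePositionBias` and passes the same instance to every layer through `rpe=self.rpe`, so all layers share one learned bucket table. A toy sketch of that sharing pattern, with a hypothetical `ToyLayer` standing in for `ParallelTransformerLayer`:

```python
import torch
from torch import nn
from megatron.model.t5rpe import RelativePositionBias

heads, num_layers = 4, 3
shared_rpe = RelativePositionBias(causal=True, num_buckets=32, max_distance=128, heads=heads)

class ToyLayer(nn.Module):          # stand-in for ParallelTransformerLayer
    def __init__(self, rpe):
        super().__init__()
        self.rpe = rpe              # same module object registered by every layer

layers = nn.ModuleList([ToyLayer(shared_rpe) for _ in range(num_layers)])
# nn.Module.parameters() deduplicates shared parameters, so only one
# num_buckets x heads embedding table exists across the three layers:
print(sum(p.numel() for p in layers.parameters()))   # 128
```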