Add GMLP (EleutherAI#339)

* gmlp initial commit

* cleanup gmlp

* update requirements.txt

* minor fixes + add gmlp config
sdtblck committed May 19, 2021
1 parent 2e473ab commit c33a2cf
Showing 11 changed files with 272 additions and 76 deletions.
72 changes: 72 additions & 0 deletions configs/gmlp_small.yml
@@ -0,0 +1,72 @@
# GPT-2 pretraining setup
{
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe-parallel-size": 1,
"model-parallel-size": 1,
"attention_config": [[["gmlp"], "all"]],


# model settings
"num-layers": 12,
"hidden-size": 768, # gmlp d_ff defaults to hidden_size * 4
"gmlp_attn_dim": 64,
"num-attention-heads": 12, # this has no effect with gmlp - and amlp defaults to single head attention.
"seq-length": 2048,
"max-position-embeddings": 2048,
"norm": "layernorm",
"pos-emb": "none",
"no-weight-tying": true,

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0006,
"betas": [0.9, 0.999],
"eps": 1.0e-8,
}
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data-impl": "mmap",
"split": "949,50,1",

# activation checkpointing
"checkpoint-activations": true,
"checkpoint-num-layers": 1,
"partition-activations": false,
"synchronize-each-layer": true,

# regularization
"gradient_clipping": 1.0,
"weight-decay": 0.1,
"hidden-dropout": 0.0,
"attention-dropout": 0.0,

# precision settings
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train-iters": 320000,
"lr-decay-iters": 320000,
"distributed-backend": "nccl",
"lr-decay-style": "cosine",
"warmup": 0.01,
"save-interval": 10000,
"eval-interval": 1000,
"eval-iters": 10,

# logging
"log-interval": 100,
"steps_per_print": 10,
"keep-last-n-checkpoints": 4,
"wall_clock_breakdown": true,
}
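
Note on the "attention_config" shorthand above: each entry pairs a pattern of layer types with a repeat count (or "all"), and the pattern is tiled across the model's layers. Below is a rough, self-contained sketch of that expansion — illustrative only, not the repo's actual parsing code (the expand_attention_config helper is hypothetical):

from itertools import cycle, islice

def expand_attention_config(attention_config, num_layers):
    # Expand [[pattern, count_or_"all"], ...] into one layer-type string per layer.
    per_layer = []
    for pattern, count in attention_config:
        n = num_layers - len(per_layer) if count == "all" else count
        per_layer.extend(islice(cycle(pattern), n))
    return per_layer

# With the config above, every one of the 12 layers is built as a gMLP block:
print(expand_attention_config([[["gmlp"], "all"]], num_layers=12))
# a list of 12 "gmlp" entries
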
108 changes: 108 additions & 0 deletions megatron/model/gmlp.py
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.activations import get_activation
from megatron.model.norms import get_norm
from megatron import mpu


class TinyAttention(nn.Module):
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
super().__init__()
self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
self.scale = d_attn ** -0.5
self.seq_len = neox_args.seq_length
self.proj_ffn = nn.Linear(d_attn, d_ff)
self.softmax = FusedScaleMaskSoftmax(
input_in_fp16=neox_args.precision == "fp16",
upper_triang_mask_fusion=neox_args.scaled_upper_triang_masked_softmax_fusion,
general_mask_fusion=neox_args.scaled_masked_softmax_fusion,
mask_func=mask_fn,
softmax_in_fp32=neox_args.attention_softmax_in_fp32,
scale=None)

def forward(self, x, attention_mask):
q, k, v = torch.chunk(self.proj_qkv(x), 3, dim=-1)
w = torch.einsum("bnd,bmd->bnm", q, k).unsqueeze(1) * self.scale
a = self.softmax(w, mask=attention_mask[..., :w.size(-2), :w.size(-1)]).squeeze(1)
x = torch.einsum("bnm,bmd->bnd", a, v)
return self.proj_ffn(x)


class SpatialGatingUnit(nn.Module):
def __init__(self, neox_args, d_ff, d_attn=None, causal=True, mask_fn=None):
super().__init__()
self.causal = causal
norm, eps = get_norm(neox_args)
self.norm = norm(d_ff, eps=eps)
self.proj = nn.Linear(neox_args.seq_length, neox_args.seq_length)
self.use_attn = d_attn is not None
if self.use_attn:
assert mask_fn is not None
self.attn = TinyAttention(neox_args=neox_args, d_attn=d_attn, d_ff=d_ff, mask_fn=mask_fn)
nn.init.zeros_(self.proj.weight)
nn.init.constant_(self.proj.bias, 1.)

def forward(self, x, attention_mask):
x = x.transpose(0, 1) # [s, b, d] -> [b, s, d]
res, gate = x.chunk(2, dim=-1) # split along dim
gate = self.norm(gate)
weight = self.proj.weight
if self.causal:
mask = torch.ones(weight.shape[:2], device=gate.device).triu_(1).bool()
weight = weight.masked_fill(mask, 0.)
gate = F.linear(gate.transpose(2, 1), weight, self.proj.bias).transpose(2, 1)
if self.use_attn:
gate = gate + self.attn(x, attention_mask)
return (gate * res).transpose(0, 1) # [b, s, d] -> [s, b, d]


class GMLPBlock(nn.Module):
def __init__(self, neox_args, init_method, output_layer_init_method, layer_number, ff_mult=4, mask_fn=None):
super().__init__()
self.layer_number = layer_number
ff_dim = neox_args.hidden_size * ff_mult
norm, eps = get_norm(neox_args)
self.norm = norm(neox_args.hidden_size, eps=eps)
self.input_linear = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim * 2,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
self.activation_func = get_activation(neox_args)
ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
if neox_args.attention_config[layer_number] == "amlp":
d_attn = neox_args.gmlp_attn_dim
else:
d_attn = None
self.sgu = SpatialGatingUnit(neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn)
self.output_linear = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)

def forward(self, args):
in_inference = len(args) == 4
in_train = len(args) == 2
if in_train:
x, attention_mask = args
elif in_inference:
x, layer_past, presents, attention_mask = args
else:
raise ValueError
x = self.norm(x)
x, _ = self.input_linear(x)
x = self.activation_func(x)
x = self.sgu(x, attention_mask)
x, _ = self.output_linear(x)
if in_train:
return x, attention_mask
elif in_inference:
return x, layer_past, presents, attention_mask
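
To make the shapes easier to follow, here is a minimal, dependency-free sketch of the spatial-gating idea in the file above — plain PyTorch, batch-first layout, no model parallelism and no tiny attention. It mirrors SpatialGatingUnit as a reading aid and is not the committed implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSGU(nn.Module):
    # Split the channels in half, normalize one half, gate it with a (causally
    # masked) linear projection over the sequence axis, and multiply the gated
    # half back onto the other half.
    def __init__(self, d_ff, seq_len, causal=True):
        super().__init__()
        self.causal = causal
        self.norm = nn.LayerNorm(d_ff)
        self.proj = nn.Linear(seq_len, seq_len)
        nn.init.zeros_(self.proj.weight)        # gate starts near-identity (bias of 1)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x):                       # x: [batch, seq, 2 * d_ff]
        res, gate = x.chunk(2, dim=-1)
        gate = self.norm(gate)
        weight = self.proj.weight
        if self.causal:                         # zero out connections to future positions
            mask = torch.ones_like(weight).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.)
        gate = F.linear(gate.transpose(2, 1), weight, self.proj.bias).transpose(2, 1)
        return gate * res                       # [batch, seq, d_ff]

x = torch.randn(2, 16, 2 * 64)                  # batch=2, seq_len=16, d_ff=64
print(SimpleSGU(d_ff=64, seq_len=16)(x).shape)  # torch.Size([2, 16, 64])
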
74 changes: 35 additions & 39 deletions megatron/model/gpt2_model.py
@@ -23,13 +23,14 @@

from functools import partial
from megatron.model.utils import Lambda, SequentialWrapper
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.norms import get_norm
from megatron.model.init_functions import get_init_methods

from megatron import mpu
from megatron.mpu import ParallelRelativePositionBias
import megatron.fp16 as fp16
from megatron.model.transformer import ParallelTransformerLayerPipe, NormPipe, ParallelLinearPipe, parallel_lm_logits
from megatron.model.gmlp import GMLPBlock
from megatron.model.word_embeddings import EmbeddingPipe

# Pipeline parallelism
@@ -79,15 +80,10 @@ def __init__(self, neox_args, num_tokentypes=0, parallel_output=True, topology=N
self.hidden_size = self.neox_args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method, self.output_layer_init_method = get_init_methods(self.neox_args)
self.fp16_lm_cross_entropy = self.neox_args.fp16_lm_cross_entropy
self.embedding_type = self.neox_args.pos_emb

#
# forward() prototype
#
self.specs = []
self.init_specs()
loss_fn = partial(cross_entropy, _fp16=self.fp16_lm_cross_entropy)
loss_fn = partial(cross_entropy, _fp16=self.neox_args.fp16_lm_cross_entropy)
if self.neox_args.checkpoint_activations:
interval = self.neox_args.checkpoint_num_layers
else:
@@ -96,19 +92,15 @@ def __init__(self, neox_args, num_tokentypes=0, parallel_output=True, topology=N
loss_fn=loss_fn if not self._inference else None,
topology=topology,
activation_checkpoint_interval=interval,
partition_method='type:transformer')
partition_method=neox_args.pipe_partition_method,
checkpointable_layers=['GMLPBlock', 'ParallelTransformerLayerPipe'])

def init_specs(self):
weight_tying = not self.neox_args.no_weight_tying
if self.embedding_type == 'rpe':
rpe_emb = ParallelRelativePositionBias(neox_args=self.neox_args, causal=True, num_buckets=self.neox_args.rpe_num_buckets,
max_distance=self.neox_args.rpe_max_distance,
heads=self.neox_args.num_attention_heads)
self.fp16_lm_cross_entropy = self.neox_args.fp16_lm_cross_entropy

#
# forward() prototype
#
self.specs = []
# Embedding layer
# input will be (input_ids, position_ids, attention_mask) in Training
@@ -151,42 +143,46 @@ def init_specs(self):
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))

# Transformer layers
for x in range(self.neox_args.num_layers):
self.specs.append(
LayerSpec(
ParallelTransformerLayerPipe,
neox_args=self.neox_args,
attention_mask_func=gpt2_attention_mask_func,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_number=x,
rpe=rpe_emb if self.neox_args.pos_emb == 'rpe' else None,
rotary=self.neox_args.pos_emb == 'rotary',
get_key_value=self.get_key_value
for i in range(self.neox_args.num_layers):
layer_type = self.neox_args.attention_config[i]
if layer_type in ["gmlp", "amlp"]:
self.specs.append(
LayerSpec(
GMLPBlock,
init_method=self.init_method,
layer_number=i,
output_layer_init_method=self.output_layer_init_method,
neox_args=self.neox_args,
mask_fn=gpt2_attention_mask_func
)
)
else:
self.specs.append(
LayerSpec(
ParallelTransformerLayerPipe,
neox_args=self.neox_args,
attention_mask_func=gpt2_attention_mask_func,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_number=i,
rpe=rpe_emb if self.neox_args.pos_emb == 'rpe' else None,
rotary=self.neox_args.pos_emb == 'rotary',
get_key_value=self.get_key_value
)
)

if self._inference:
# we can get rid of the mask / pasts / (?rotary_pos_emb) now
# from (hidden_states, layer_past, presents, (maybe rotary_pos_emb), attention_mask)
# to (hidden_states^T, presents)
# we can get rid of the mask / pasts now
# from (hidden_states, layer_past, presents, attention_mask)
# to (hidden_states.T, presents)
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[2]))
else:
# Undo data format change and drop mask
self.specs.append(lambda x: x[0].transpose(0, 1).contiguous())

# Final layernorm after transformer layers
if self.neox_args.norm == "rmsnorm":
norm = RMSNorm
eps = self.neox_args.rms_norm_epsilon
elif self.neox_args.norm == "layernorm":
eps = self.neox_args.layernorm_epsilon
norm = LayerNorm
elif self.neox_args.norm == "scalenorm":
eps = self.neox_args.scalenorm_epsilon
norm = ScaleNorm

# NormPipe is a helper class to pass presents through to the output when doing inference
norm, eps = get_norm(self.neox_args)
self.specs.append(
LayerSpec(NormPipe,
norm,
@@ -239,7 +235,7 @@ def _logits_helper(embedding, lm_output):
parallel_output=self.parallel_output
)
)
# so output in training should just be logits
# output in training should just be logits
# in inference it will be (logits, presents) (assuming get_key_value is true)

def to_sequential(self):
13 changes: 13 additions & 0 deletions megatron/model/norms.py
@@ -13,6 +13,19 @@
from torch.nn import LayerNorm


def get_norm(neox_args):
if neox_args.norm == "rmsnorm":
norm = RMSNorm
eps = neox_args.rms_norm_epsilon
elif neox_args.norm == "layernorm":
eps = neox_args.layernorm_epsilon
norm = LayerNorm
elif neox_args.norm == "scalenorm":
eps = neox_args.scalenorm_epsilon
norm = ScaleNorm
return norm, eps


class RMSNorm(torch.nn.Module):
def __init__(self, dim, p=-1., eps=1e-8, bias=False):
"""
(Diffs for the remaining 7 changed files are not shown here.)
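
For reference, the new get_norm helper is consumed the same way in GMLPBlock and gpt2_model.py above. A minimal usage sketch — the SimpleNamespace stands in for the real NeoXArgs object, and the snippet assumes the repo is importable:

from types import SimpleNamespace
from megatron.model.norms import get_norm

neox_args = SimpleNamespace(norm="layernorm", layernorm_epsilon=1e-5)
norm_cls, eps = get_norm(neox_args)   # -> (torch.nn.LayerNorm, 1e-5)
final_norm = norm_cls(768, eps=eps)   # instantiate with the model's hidden size
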

