Merge pull request #3 from EleutherAI/add_geglu
Add geglu
StellaAthena committed Feb 8, 2021
2 parents 624df98 + 7f194de commit 046037e
Showing 4 changed files with 71 additions and 32 deletions.
3 changes: 3 additions & 0 deletions megatron/arguments.py
@@ -259,6 +259,9 @@ def _add_training_args(parser):
'general masking and softmax.')
group.add_argument('--bias-gelu-fusion', action='store_true',
help='Enable bias and gelu fusion.')
group.add_argument('--geglu', action='store_true',
help='Enable geglu activation function (WARNING: will increase memory usage, '
'adjust embd dims accordingly)')
group.add_argument('--bias-dropout-fusion', action='store_true',
help='Enable bias and dropout fusion.')
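
For orientation, a rough sketch (not part of this diff) of the width and weight arithmetic behind the memory warning above. hidden_size is a made-up value; mult, dense_h_to_4h, and dense_4h_to_h refer to the ParallelMLP changes later in this commit.

# Illustrative only: how --geglu changes the MLP widths and weight counts.
hidden_size = 1024                      # example value, not taken from the diff
use_geglu = True                        # stands in for args.geglu

mult = 8 if use_geglu else 4            # mirrors the ParallelMLP change below
intermediate = mult * hidden_size       # output width of dense_h_to_4h
if use_geglu:
    intermediate //= 2                  # GEGLU splits its input into value and gate halves
assert intermediate == 4 * hidden_size  # input width of dense_4h_to_h is unchanged

# Weight count of the two MLP projections (biases ignored):
gelu_weights  = hidden_size * 4 * hidden_size + 4 * hidden_size * hidden_size
geglu_weights = hidden_size * 8 * hidden_size + 4 * hidden_size * hidden_size
print(geglu_weights / gelu_weights)     # 1.5x the MLP weights per layer, hence the warning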

18 changes: 6 additions & 12 deletions megatron/model/gpt2_model.py
@@ -53,7 +53,6 @@ def CrossEntropy(output, labels):
return loss



class GPT2Model(MegatronModule):
"""GPT-2 Language model."""

@@ -109,14 +108,13 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None,
loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
return loss


def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):

state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
destination, prefix, keep_vars)
return state_dict_

def load_state_dict(self, state_dict, strict=True):
@@ -127,25 +125,22 @@ def load_state_dict(self, state_dict, strict=True):
self.language_model.load_state_dict(state_dict, strict=strict)


class GPT2ModelPipe(PipelineModule,MegatronModule):
class GPT2ModelPipe(PipelineModule, MegatronModule):
"""GPT2Model adapted for pipeline parallelism.
The largest change is flattening the GPTModel class so we can express it as a
sequence of layers including embedding, transformer layers, and output.
"""

def __init__(self, num_tokentypes=0, parallel_output=True, add_pooler=False, topology=None):
def __init__(self, num_tokentypes=0, parallel_output=True, topology=None):
args = get_args()

self.parallel_output = parallel_output
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method_normal(args.init_method_std)
self.output_layer_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
self.add_pooler = add_pooler
if self.add_pooler:
raise NotImplementedError('Pipeline pooler not yet implemented. Forward needs pooling_sequence_index')


# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
@@ -170,7 +165,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True, add_pooler=False, top
# outputs are now (hidden_states, attention_mask)

# data format change to avoid explicit tranposes : [b s h] --> [s b h]
self.specs.append(lambda x: (x[0].transpose(0,1).contiguous(), x[1]))
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1]))

# Transformer layers
for x in range(args.num_layers):
@@ -181,8 +176,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True, add_pooler=False, top
output_layer_init_method=self.output_layer_init_method,
layer_number=x))
# Undo data format change and drop mask
self.specs.append(lambda x: x[0].transpose(0,1).contiguous())

self.specs.append(lambda x: x[0].transpose(0, 1).contiguous())

# Final layernorm after transformer layers
self.specs.append(
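The specs built above pass a (hidden_states, attention_mask) tuple from stage to stage, transposing activations from [b, s, h] to [s, b, h] before the transformer layers and back afterwards. A minimal sketch of that data-format change, not part of the commit; the shapes are placeholders:

import torch

hidden = torch.randn(4, 128, 1024)        # [b, s, h], example sizes
mask = torch.ones(4, 1, 128, 128)         # placeholder attention-mask shape

to_sbh = lambda x: (x[0].transpose(0, 1).contiguous(), x[1])  # same lambda as in the specs
to_bsh = lambda x: x[0].transpose(0, 1).contiguous()          # undo the transpose, drop the mask

staged = to_sbh((hidden, mask))
assert staged[0].shape == (128, 4, 1024)  # [s, b, h] fed to the transformer layers
assert to_bsh(staged).shape == (4, 128, 1024)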
20 changes: 12 additions & 8 deletions megatron/model/language_model.py
@@ -25,6 +25,7 @@
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
@@ -187,11 +188,11 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
destination, prefix, keep_vars)

return state_dict_

@@ -240,13 +241,15 @@ def load_state_dict(self, state_dict, strict=True):
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)


class EmbeddingPipe(Embedding):
"""Extends Embedding to forward attention_mask through the pipeline."""

@property
def word_embeddings_weight(self):
"""Easy accessory for the pipeline engine to tie embeddings across stages."""
return self.word_embeddings.weight

def forward(self, args):
input_ids = args[0]
position_ids = args[1]
@@ -255,10 +258,11 @@ def forward(self, args):
tokentype_ids = args[3]
else:
tokentype_ids = None

embeddings = super().forward(input_ids, position_ids, tokentype_ids=tokentype_ids)
return embeddings, attention_mask


class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
@@ -303,7 +307,7 @@ def __init__(self,

# Transformer
self.transformer = ParallelTransformer(
attention_mask_func, self.init_method,
attention_mask_func, self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'

@@ -340,14 +344,14 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
state_dict_ = {}
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
destination, prefix, keep_vars)
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
destination, prefix, keep_vars)

return state_dict_

62 changes: 50 additions & 12 deletions megatron/model/transformer.py
@@ -58,6 +58,31 @@
"""


class GEGLU(MegatronModule):

def __init__(self):
super(GEGLU, self).__init__()
args = get_args()
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu

def forward(self, x, bias):
x, gate = x.chunk(2, dim=-1)
bias_1, bias_2 = bias.chunk(2, dim=-1)
x = x + bias_1
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(gate, bias_2)
else:
intermediate_parallel = \
self.activation_func(gate + bias_2)
return intermediate_parallel * x
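
A minimal restatement, not part of the commit, of what GEGLU.forward computes once the bias has already been added and the bias_gelu_fusion fast path is ignored: the widened projection is split in half and GELU of the gate half multiplicatively gates the value half.

import torch
import torch.nn.functional as F

def geglu_reference(x_plus_bias: torch.Tensor) -> torch.Tensor:
    value, gate = x_plus_bias.chunk(2, dim=-1)   # same split as x.chunk / bias.chunk above
    return F.gelu(gate) * value                  # GELU(gate) gates the value half

y = geglu_reference(torch.randn(2, 3, 8))        # last dim 8 in, 4 out
assert y.shape == (2, 3, 4)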


class ParallelMLP(MegatronModule):
"""MLP.
@@ -71,21 +96,28 @@ def __init__(self, init_method, output_layer_init_method):
super(ParallelMLP, self).__init__()
args = get_args()

if args.geglu:
self.activation_type = "geglu"
mult = 8
self.activation_func = GEGLU()
else:
self.activation_type = "gelu"
mult = 4
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu

# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
mult * args.hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)

self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu

# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
@@ -99,12 +131,18 @@ def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)
        if self.activation_type == "gelu":
            if self.bias_gelu_fusion:
                intermediate_parallel = \
                    bias_gelu_impl(intermediate_parallel, bias_parallel)
            else:
                intermediate_parallel = \
                    self.activation_func(intermediate_parallel + bias_parallel)
        elif self.activation_type == "geglu":
            intermediate_parallel = \
                self.activation_func(intermediate_parallel, bias_parallel)
        else:
            raise ValueError(f'Activation type {self.activation_type} not recognized')

# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
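
To summarize the new dispatch, a sketch under simplifying assumptions (plain tensors, no tensor parallelism, no fused kernel; shapes are illustrative): whichever branch is taken, a 4h-wide tensor is handed to dense_4h_to_h.

import torch
import torch.nn.functional as F

def mlp_activation(intermediate, bias, activation_type="gelu"):
    if activation_type == "gelu":
        return F.gelu(intermediate + bias)           # un-fused gelu path
    elif activation_type == "geglu":
        value, gate = intermediate.chunk(2, dim=-1)  # what GEGLU does with (x, bias)
        bias_v, bias_g = bias.chunk(2, dim=-1)
        return F.gelu(gate + bias_g) * (value + bias_v)
    raise ValueError(f'Activation type {activation_type} not recognized')

h = 16
out_gelu = mlp_activation(torch.randn(2, 4 * h), torch.randn(4 * h), "gelu")
out_geglu = mlp_activation(torch.randn(2, 8 * h), torch.randn(8 * h), "geglu")
assert out_gelu.shape[-1] == out_geglu.shape[-1] == 4 * h  # both feed dense_4h_to_h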
