align gpt-j layernorm to hf (#481)
* align gpt-j layernorm to hf

* Allows for configuring the tying of GPT-J-style residuals

* Passed new config correctly

* Added config to toggle tied residual

* Update NeoXArgs docs automatically

* Raise error in HF conversion script if gpt_j_tied=True

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

Co-authored-by: Samuel Weinbach <[email protected]>
Co-authored-by: Shivanshu Purohit <[email protected]>
Co-authored-by: Stella Biderman <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: haileyschoelkopf <[email protected]>
Co-authored-by: Hailey Schoelkopf <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
8 people authored Dec 6, 2022
1 parent aae696d commit 0accac6
Showing 4 changed files with 52 additions and 9 deletions.
configs/neox_arguments.md (14 changes: 13 additions & 1 deletion)
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 94c66b3
Default = bcb277e

current git hash of repository

@@ -535,6 +535,18 @@ Model Arguments



- **gpt_j_tied**: bool

Default = False

If false, we use
x = x + attn(ln1(x)) + mlp(ln2(x))
Otherwise, we tie the layer norms
y = ln(x)
x = x + attn(y) + mlp(y)



- **soft_prompt_tuning**: dict

Default = None
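For reference, the two formulations toggled by the new `gpt_j_tied` argument can be written out directly. The snippet below is a minimal illustrative sketch using plain PyTorch modules as stand-ins for NeoX's attention and MLP blocks; `attn`, `mlp`, `ln`, `ln1`, and `ln2` are hypothetical placeholders, not the actual implementations.

# Illustrative only: `attn`, `mlp`, `ln`, `ln1`, `ln2` are stand-in modules,
# not the NeoX attention/MLP implementations.
import torch
import torch.nn as nn

hidden = 64
x = torch.randn(2, 8, hidden)        # [batch, seq, hidden]
attn = nn.Linear(hidden, hidden)     # placeholder for the attention block
mlp = nn.Linear(hidden, hidden)      # placeholder for the MLP block

# gpt_j_tied = False: each parallel branch gets its own layernorm
ln1, ln2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
out_untied = x + attn(ln1(x)) + mlp(ln2(x))

# gpt_j_tied = True: a single layernorm is shared by both branches
ln = nn.LayerNorm(hidden)
y = ln(x)
out_tied = x + attn(y) + mlp(y)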
megatron/model/transformer.py (28 changes: 20 additions & 8 deletions)
@@ -544,6 +544,7 @@ def __init__(
self.hidden_dropout = neox_args.hidden_dropout
self.bias_dropout_fusion = neox_args.bias_dropout_fusion
self.gpt_j_residual = neox_args.gpt_j_residual
self.gpt_j_tied = neox_args.gpt_j_tied

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
@@ -562,6 +563,8 @@
)

# Layernorm on the output of the attention layer.
# If GPT-J residuals are used, this is superfluous, but leaving it in
# leads to cleaner code
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

# MLP
@@ -591,14 +594,23 @@ def forward(self, x, attention_mask, layer_past=None):
# x: [b, s, h]
if self.gpt_j_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
# x = x + attn(ln(x)) + mlp(ln(x))
# this means we can avoid doing the allreduce in the attn / mlp outputs
# to save communication time (we can do a single allreduce after we add mlp / attn outputs).

# attention_output = attn(ln1(x))
# due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is not desirable, but
# we preserve the functionality for backwards compatibility

residual = x
# applies the correct normalization depending on if the norms are tied
if self.gpt_j_tied:
    x = self.input_layernorm(x)
    x1, x2 = x, x
else:
    x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x)

# attention operator
attention_output, attention_bias = self.attention(
self.input_layernorm(x), attention_mask, layer_past=layer_past
x1, attention_mask, layer_past=layer_past
)
if self.use_cache:
attention_output, presents = attention_output
@@ -612,8 +624,8 @@ def forward(self, x, attention_mask, layer_past=None):
prob=self.hidden_dropout,
)

# output = mlp(ln2(x)) + attention_output
mlp_output, mlp_bias = self.mlp(self.post_attention_layernorm(x))
# mlp operator
mlp_output, mlp_bias = self.mlp(x2)
with torch.enable_grad():
output = bias_dropout_fn(
mlp_output,
@@ -622,8 +634,8 @@ def forward(self, x, attention_mask, layer_past=None):
prob=self.hidden_dropout,
)

# output = output + residual
output = residual + self.reduce(output)
# output = x + attn(ln(x)) + mlp(ln(x))
output = residual + self.reduce(output)
else:
# pseudocode:
# x = x + attn(ln1(x))
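To make the communication argument in the pseudocode concrete, here is a condensed single-device sketch of the modified parallel-residual branch. `gpt_j_block_forward` and the identity `reduce` are illustrative stand-ins, not the megatron implementation; bias-dropout fusion, caching, and tensor parallelism are omitted.

import torch
import torch.nn as nn

def gpt_j_block_forward(x, ln1, ln2, attn, mlp, gpt_j_tied, reduce=lambda t: t):
    """Condensed control flow of the parallel-residual branch above (sketch only)."""
    residual = x
    if gpt_j_tied:
        x = ln1(x)                   # one (tied) layernorm feeds both branches
        x1, x2 = x, x
    else:
        x1, x2 = ln1(x), ln2(x)      # untied: one layernorm per branch

    attention_output = attn(x1)      # attention operator
    mlp_output = mlp(x2)             # mlp operator

    # summing the branch outputs first means a single all-reduce suffices
    # in the tensor-parallel case (here `reduce` is just the identity)
    return residual + reduce(attention_output + mlp_output)

# usage
hidden = 64
ln1, ln2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
attn, mlp = nn.Linear(hidden, hidden), nn.Linear(hidden, hidden)
x = torch.randn(2, 8, hidden)
out = gpt_j_block_forward(x, ln1, ln2, attn, mlp, gpt_j_tied=True)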
megatron/neox_arguments/neox_args.py (9 changes: 9 additions & 0 deletions)
@@ -331,6 +331,15 @@ class NeoXArgsModel(NeoXArgsTemplate):
x = ln(x)
x = x + attn(x) + mlp(x)
"""

gpt_j_tied: bool = False
"""
If false, we use
x = x + attn(ln1(x)) + mlp(ln2(x))
Otherwise, we tie the layer norms
y = ln(x)
x = x + attn(y) + mlp(y)
"""

soft_prompt_tuning: dict = None
"""
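A quick way to see that `gpt_j_tied` changes parameterization rather than the functional form: the untied expression collapses to the tied one whenever the two layernorms happen to carry identical weights, which is exactly what the tied option enforces. A self-contained check, with `attn` and `mlp` as arbitrary stand-in modules (assumptions, not NeoX code):

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 32
x = torch.randn(2, 4, hidden)
attn, mlp = nn.Linear(hidden, hidden), nn.Linear(hidden, hidden)

ln1 = nn.LayerNorm(hidden)
ln2 = nn.LayerNorm(hidden)
ln2.load_state_dict(ln1.state_dict())     # force identical layernorm weights

untied = x + attn(ln1(x)) + mlp(ln2(x))   # gpt_j_tied = False
y = ln1(x)                                # gpt_j_tied = True
tied = x + attn(y) + mlp(y)

print(torch.allclose(untied, tied))       # True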
tools/convert_to_hf.py (10 changes: 10 additions & 0 deletions)
@@ -89,6 +89,16 @@ def __init__(self, neox_config):
except:
pad_token = 1 # pad defaulting to 1. follows convention from GPT-NeoX-20b tokenizer


# TODO: change the default value here based on discussion regarding `gpt_j_tied` config parameter's default
use_tied_lns = get_key(neox_config, 'gpt-j-tied', False)

if use_tied_lns:
    raise NotImplementedError(
        """ERROR: Huggingface Transformers does not yet support a single shared layernorm
        per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals.
        See https://github.com/EleutherAI/gpt-neox/pull/481 for further details."""
    )

# set all config values.
hf_config = GPTNeoXConfig(
vocab_size=args.padded_vocab_size,
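The guard's intent, in isolation: look up the dash-separated config key with a default of False and refuse to convert checkpoints trained with tied layernorms. A standalone sketch follows; the `get_key` shown here is a hypothetical stand-in for the converter's helper, not the script itself.

def get_key(config: dict, key: str, default=None):
    # hypothetical stand-in for the converter's config lookup helper
    return config.get(key, default)

def check_convertible(neox_config: dict) -> None:
    use_tied_lns = get_key(neox_config, "gpt-j-tied", False)
    if use_tied_lns:
        raise NotImplementedError(
            "HF GPT-NeoX blocks carry two layernorms, so a checkpoint trained "
            "with a single shared layernorm cannot be exported directly. "
            "See https://github.com/EleutherAI/gpt-neox/pull/481."
        )

check_convertible({"gpt-j-tied": False})    # passes silently
# check_convertible({"gpt-j-tied": True})   # would raise NotImplementedError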
