Merge pull request EleutherAI#761 from EleutherAI/fix_tied_ln
Fix behavior of `gpt_j_tied` synced LayerNorms
StellaAthena authored Dec 27, 2022
2 parents 4afbaeb + 9c67324 commit 679ebca
Showing 3 changed files with 15 additions and 6 deletions.
configs/20B.yml (1 change: 0 additions & 1 deletion)

@@ -27,7 +27,6 @@
 "rotary_pct": 0.25,
 "no-weight-tying": true,
 "gpt_j_residual": true,
-"gpt_j_tied": true,
 "output_layer_parallelism": "column",
 "scaled-upper-triang-masked-softmax-fusion": true,
 "bias-gelu-fusion": true,
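A note on this change: the commit alters what `gpt_j_tied` computes (see the transformer.py hunk below), so removing the key here presumably keeps the 20B config on the two-LayerNorm code path that `gpt_j_tied: true` produced before the fix. A minimal sketch of that reasoning, assuming an absent flag defaults to false; the `config` dict and lookup are illustrative, not the repo's actual argument loader:

```python
# Illustrative sketch only, not the repo's config loader.
# Assumption: a flag that is absent from the YAML config defaults to False.
config = {
    "gpt_j_residual": True,
    # "gpt_j_tied": True,  # line removed by this commit
}

gpt_j_tied = config.get("gpt_j_tied", False)

# Untied path: the attention and MLP inputs each get their own LayerNorm,
# which is also what "gpt_j_tied": true produced before this fix (see the
# transformer.py hunk below).
print("tied" if gpt_j_tied else "untied")  # -> untied
```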
configs/neox_arguments.md (16 changes: 13 additions & 3 deletions)

@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 9fd1b6b
+    Default = 27e56e3

     current git hash of repository

@@ -923,6 +923,15 @@ Text Generation arguments



+- **prompt_end**: str
+
+    Default = 
+
+
+    a single prompt's end. Defaults to newline
+
+
+
 - **sample_input_file**: str

     Default = None

@@ -958,7 +967,7 @@ Text Generation arguments

 - **eval_results_prefix**: str

-    Default = 
+    Default = 

     prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json

@@ -1572,7 +1581,7 @@ Args for deepspeed config

     Default = None

-
+

@@ -1706,3 +1715,4 @@ Args for deepspeed runner (deepspeed.launcher.runner).
     Default = None

     Adds a `--comment` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometime necessary for cluster rules, or so I've heard.
+
megatron/model/transformer.py (4 changes: 2 additions & 2 deletions)

@@ -667,10 +667,10 @@ def forward(self, x, attention_mask, layer_past=None):
         residual = x
         # applies the correct normalization depending on if the norms are tied
         if self.gpt_j_tied:
-            x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x)
-        else:
             x = self.input_layernorm(x)
             x1, x2 = x, x
+        else:
+            x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x)

         # attention operator
         attention_output, attention_bias = self.attention(
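For context on what the corrected branch does, below is a minimal, self-contained sketch of a GPT-J-style parallel residual block. It is illustrative only, not the repo's actual transformer layer: the attention and MLP stand-ins are plain `nn.Linear` modules and the parallel residual is simplified to `residual + attn(x1) + mlp(x2)`. With `gpt_j_tied` the input is normalized once and the same tensor feeds both branches; otherwise `input_layernorm` and `post_attention_layernorm` are applied independently.

```python
# Minimal sketch of the tied vs. untied LayerNorm behavior in a GPT-J-style
# parallel residual block. Illustrative stand-in for the repo's layer.
import torch
import torch.nn as nn


class ParallelBlockSketch(nn.Module):
    def __init__(self, hidden_size: int, gpt_j_tied: bool):
        super().__init__()
        self.gpt_j_tied = gpt_j_tied
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.attention = nn.Linear(hidden_size, hidden_size)  # stand-in
        self.mlp = nn.Linear(hidden_size, hidden_size)        # stand-in

    def forward(self, x):
        residual = x
        # applies the correct normalization depending on if the norms are tied
        if self.gpt_j_tied:
            x = self.input_layernorm(x)  # normalize once ...
            x1, x2 = x, x                # ... and share it across both branches
        else:
            x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x)
        # simplified parallel residual: attention on x1, MLP on x2
        return residual + self.attention(x1) + self.mlp(x2)


if __name__ == "__main__":
    x = torch.randn(2, 4, 8)
    print(ParallelBlockSketch(8, gpt_j_tied=True)(x).shape)   # torch.Size([2, 4, 8])
    print(ParallelBlockSketch(8, gpt_j_tied=False)(x).shape)  # torch.Size([2, 4, 8])
```

Before this commit the two branch bodies were effectively swapped, so setting `gpt_j_tied` still ran two independent LayerNorms; the hunk above moves the single-norm path under the tied case.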
