Skip to content

Commit

Permalink
Add QK Normalization (#1100)
Browse files Browse the repository at this point in the history
* add qk normalization

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
  • Loading branch information
3 people committed Dec 22, 2023
1 parent 9283eff commit f161245
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
10 changes: 9 additions & 1 deletion configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = a279fc8
Default = 1fc0521

current git hash of repository

Expand Down Expand Up @@ -261,6 +261,14 @@ Model Arguments



- **use_qk_layernorm**: bool

Default = False

Use QK Normalization



- **layernorm_epsilon**: float

Default = 1e-05
Expand Down
15 changes: 15 additions & 0 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ def __init__(
neox_args.num_attention_heads, world_size
)
self.pos_emb = neox_args.pos_emb
self.use_qk_layernorm = neox_args.use_qk_layernorm
if self.use_qk_layernorm:
norm, eps = get_norm(neox_args)
self.qk_layernorm = norm(
[
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
],
eps=eps,
)

# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
Expand Down Expand Up @@ -639,6 +649,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
mixed_x_layer, 3
)

# QK Normalization https://arxiv.org/abs/2302.05442
if self.use_qk_layernorm:
query_layer = self.qk_layernorm(query_layer)
key_layer = self.qk_layernorm(key_layer)

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
# partial rotary
Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
"""

use_qk_layernorm: bool = False
"""
Use QK Normalization
"""

layernorm_epsilon: float = 1.0e-5
"""
Layer norm epsilon.
Expand Down

0 comments on commit f161245

Please sign in to comment.