Make rotary freqs buffer non-persistent (#1168)
* make inv_freq non-persistent by default

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
3 people committed Mar 4, 2024
1 parent 31cfe52 commit e109bf5
Showing 6 changed files with 28 additions and 42 deletions.
14 changes: 13 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = b804ee8
Default = 2a3c4e1

current git hash of repository

@@ -601,6 +601,18 @@ Model Arguments



- **rotary_save_freqs_buffer**: bool

Default = False

Controls whether the `inv_freq` buffer in rotary embeddings
is stored in checkpoints (persistent=True) or not.

Defaults to False, but is left configurable to maintain backward compatibility
with GPT-NeoX checkpoints that saved this buffer.



- **init_method**: typing.Literal['normal', 'scaled_normal', 'orthogonal', 'scaled_orthogonal', 'xavier_uniform', 'xavier_normal', 'wang_init', 'small_init']

Default = normal
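
For context on what the new `persistent` argument controls, here is a minimal plain-PyTorch sketch (not part of this commit): a buffer registered with `persistent=False` stays available at runtime but is excluded from `state_dict()`, so it is never written into a checkpoint.

```python
import torch


class Demo(torch.nn.Module):
    def __init__(self, save_inv_freqs: bool = False):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        # persistent=False keeps the tensor usable as self.inv_freq,
        # but drops it from state_dict() and therefore from checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs)


print("inv_freq" in Demo(save_inv_freqs=True).state_dict())   # True
print("inv_freq" in Demo(save_inv_freqs=False).state_dict())  # False
```
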
8 changes: 5 additions & 3 deletions megatron/model/positional_embeddings.py
@@ -36,10 +36,12 @@ def forward(self, x, seq_dim=1):


class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_seq_len, base=10000, precision=torch.half):
def __init__(
self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
@@ -53,7 +55,7 @@ def __init__(self, dim, max_seq_len, base=10000, precision=torch.half):
max_seq_len, precision, base
)

self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs)
self.cos_cached = cos_cached
self.sin_cached = sin_cached

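A usage sketch based on the constructor signature shown above (the import path follows this repository's layout; the cos/sin cache attributes are internal, so only `inv_freq` and `state_dict()` are checked): with `save_inv_freqs=False` the module still computes and uses the buffer, but the key no longer appears in its `state_dict()`.

```python
import torch
from megatron.model.positional_embeddings import RotaryEmbedding

rot = RotaryEmbedding(
    dim=64, max_seq_len=2048, base=10000, precision=torch.half, save_inv_freqs=False
)

assert rot.inv_freq.shape == (32,)          # still present for forward passes
assert "inv_freq" not in rot.state_dict()   # but no longer serialized
```
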
1 change: 1 addition & 0 deletions megatron/model/transformer.py
@@ -385,6 +385,7 @@ def __init__(
base=neox_args.rotary_emb_base,
max_seq_len=neox_args.seq_length,
precision=neox_args.params_dtype,
save_inv_freqs=neox_args.rotary_save_freqs_buffer,
)
else:
self.rotary_emb = None
9 changes: 9 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -330,6 +330,15 @@ class NeoXArgsModel(NeoXArgsTemplate):
Base for rotary positional embedding
"""

rotary_save_freqs_buffer: bool = False
"""
Controls whether the `inv_freq` buffer in rotary embeddings
is stored in checkpoints (persistent=True) or not.
Defaults to False, but is left configurable to maintain backward compatibility
with GPT-NeoX checkpoints that saved this buffer.
"""

init_method: Literal[
"normal",
"scaled_normal",
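
One way to see why the flag stays configurable: with plain `torch.nn.Module` loading (GPT-NeoX checkpoints go through DeepSpeed, so this is only an analogy), an `inv_freq` entry saved by an older run becomes an unexpected key once the buffer is non-persistent. A hedged sketch:

```python
import torch


class OldStyle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer, as before this change: saved into checkpoints.
        self.register_buffer("inv_freq", torch.ones(4))


class NewStyle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: absent from new checkpoints.
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)


old_ckpt = OldStyle().state_dict()                        # contains "inv_freq"
result = NewStyle().load_state_dict(old_ckpt, strict=False)
print(result.unexpected_keys)                             # ['inv_freq']; strict=True would raise
```
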
8 changes: 0 additions & 8 deletions tools/ckpts/convert_neox_to_hf.py
@@ -520,14 +520,6 @@ def convert(
)

# Just take one
if "attention.rotary_emb.inv_freq" in hf_layer.state_dict():
state_dict["attention.rotary_emb.inv_freq"] = get_state(
loaded_tp_ranks,
"attention.rotary_emb.inv_freq",
layer_idx=layer_i + 2,
sequential=sequential,
)[0]

if "attention.bias" in hf_layer.state_dict():
state_dict["attention.bias"] = hf_layer.state_dict()["attention.bias"]
if "attention.masked_bias" in hf_layer.state_dict():
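
Dropping the `attention.rotary_emb.inv_freq` copy is reasonable because the buffer is a pure function of the rotary base and the per-head dimension, so the destination model can always rebuild it instead of having it carried over during conversion. A small illustration with arbitrary values:

```python
import torch


def rope_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Deterministic given (head_dim, base); nothing checkpoint-specific to preserve.
    return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))


# Two independent recomputations are bit-identical, so a converter can omit
# the stored buffer and let the target model recreate it.
assert torch.equal(rope_inv_freq(64), rope_inv_freq(64))
print(rope_inv_freq(8))  # values 1.0, 0.1, 0.01, 0.001
```
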
30 changes: 0 additions & 30 deletions tools/ckpts/convert_raw_llama_weights_to_neox.py
@@ -158,20 +158,6 @@ def permute_rotary(w):
pbar.update(1)

# Layers
if model_size == "7B":
rope_freqs = loaded[0]["layers.0.attention.inner_attention.rope.freqs"]
helper.del_loaded("layers.0.attention.inner_attention.rope.freqs")
elif "mistral" in model_size:
# mistral does not include rope freqs in the distributed checkpoint, unlike llama.
# rather than making this buffer always non-persistent on the NeoX side,
# just create and save it for Mistral.
base = 10000.0
rope_freqs = 1.0 / (
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
else:
rope_freqs = loaded[0]["rope.freqs"]
helper.del_loaded("rope.freqs")
for layer_i in range(num_layers):

# Linear
@@ -293,7 +279,6 @@ def permute_rotary(w):
# Duplicated layers
"input_layernorm.scale": input_layernorm,
"post_attention_layernorm.scale": post_attention_layernorm,
"attention.rotary_emb.inv_freq": rope_freqs,
},
layer_i=layer_i + 2,
rank=out_rank,
@@ -409,20 +394,6 @@ def permute_rotary(w):
helper.del_loaded("output.weight")

# Layers
if model_size == "7B":
rope_freqs = loaded[0]["layers.0.attention.inner_attention.rope.freqs"]
helper.del_loaded("layers.0.attention.inner_attention.rope.freqs")
elif "mistral" in model_size:
# mistral does not include rope freqs in the distributed checkpoint, unlike llama.
# rather than making this buffer always non-persistent on the NeoX side,
# just create and save it for Mistral.
base = 10000.0
rope_freqs = 1.0 / (
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
else:
rope_freqs = loaded[0]["rope.freqs"]
helper.del_loaded("rope.freqs")
for layer_i in range(num_layers):

# Linear
@@ -546,7 +517,6 @@ def permute_rotary(w):
# Duplicated layers
"input_layernorm.scale": input_layernorm,
"post_attention_layernorm.scale": post_attention_layernorm,
"attention.rotary_emb.inv_freq": rope_freqs,
},
layer_i=layer_i + 2,
rank=out_rank,
