[muP] Rework #1087

Draft · wants to merge 107 commits into base: main
Changes from 31 commits (107 commits total)
0d921f7
changed ordering for setting up norm_factor
lintangsutawika Dec 1, 2023
abee54d
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
a08c3ef
updated muP args to the minimum required
lintangsutawika Dec 1, 2023
c35e830
calculate m_width
lintangsutawika Dec 1, 2023
2807e52
Merge branch 'main' of https://github.com/EleutherAI/gpt-neox into re…
lintangsutawika Dec 1, 2023
2d127df
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika Dec 1, 2023
81fdc4d
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
7d6b246
changed ordering for setting up norm_factor
lintangsutawika Dec 1, 2023
a0d1929
updated muP args to the minimum required
lintangsutawika Dec 1, 2023
d63b3b8
calculate m_width
lintangsutawika Dec 1, 2023
9be82fe
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
66214d9
removed redundant line
lintangsutawika Dec 1, 2023
17b7183
removed redundant lines
lintangsutawika Dec 1, 2023
a6bad07
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
63984bd
removed redundant lines
lintangsutawika Dec 1, 2023
02687a8
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika Dec 1, 2023
11114e2
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
05c4de3
modify init with mup
lintangsutawika Dec 1, 2023
71a91e4
divide logits by the m_width
lintangsutawika Dec 1, 2023
99c8ce0
moved position of mup parameters being processed
lintangsutawika Dec 1, 2023
b253ab6
add note
lintangsutawika Dec 1, 2023
1919499
made param groups to hold flag for mup scaling
lintangsutawika Dec 6, 2023
17678e0
lr scale
lintangsutawika Dec 6, 2023
2bd5ae6
update config
lintangsutawika Dec 6, 2023
6642291
adjust process of mup variables
lintangsutawika Dec 6, 2023
8be6c66
remove calling save_base_shapes
lintangsutawika Dec 18, 2023
c9fb18b
lr adjustments is done in train_step to address lr being reset due to…
lintangsutawika Dec 18, 2023
795371c
lr scaling for mup is moved here instead
lintangsutawika Dec 18, 2023
087beee
removed mup usage for coord check
lintangsutawika Jan 3, 2024
16d04b1
merged with main
lintangsutawika Jan 3, 2024
e7b7bf6
latest update on coord check implementation
lintangsutawika Jan 24, 2024
8dea9ce
fix merge conflict
lintangsutawika Feb 2, 2024
3664eba
changed `mup_m_width` to `mup_width_multiplier`
lintangsutawika Feb 2, 2024
6a46247
fixed notations
lintangsutawika Feb 2, 2024
7439f9a
correct scale
lintangsutawika Feb 2, 2024
5b2d31c
m_emb * embed(X)
lintangsutawika Feb 2, 2024
98caa82
removed mup rescale in the layers
lintangsutawika Feb 2, 2024
5c99637
removed mup rescale in the layers
lintangsutawika Feb 2, 2024
a636f06
adjust mup_m_emb to mup_embedding_multiplier
lintangsutawika Feb 2, 2024
39190c5
add multiplier mup_output_multiplier
lintangsutawika Feb 20, 2024
2489cc0
reorder model loading
lintangsutawika Feb 20, 2024
23b8776
removed comments
lintangsutawika Feb 20, 2024
10e935e
removed comments
lintangsutawika Feb 20, 2024
a0aca99
implement full process
lintangsutawika Feb 20, 2024
9472b35
set neox_args.iteration to 0 for coord_check mode
lintangsutawika Feb 21, 2024
5c5f2df
move mup_width_multiplier init
lintangsutawika Feb 21, 2024
7eca3e7
mup_coord_check returns 2 df
lintangsutawika Feb 21, 2024
c9a3a65
can run
lintangsutawika Feb 21, 2024
a7877d4
remove comments
lintangsutawika Feb 22, 2024
bd9d399
add hooks
lintangsutawika Feb 22, 2024
fe180d3
remove comments
lintangsutawika Feb 22, 2024
b240c19
uncomment activation data
lintangsutawika Feb 22, 2024
93b4241
plot coords
lintangsutawika Feb 22, 2024
d4899fc
removed variables, add way to plot only from rank 0
lintangsutawika Feb 22, 2024
f589e29
changed key name in dict
lintangsutawika Feb 22, 2024
8261e0d
remove print
lintangsutawika Feb 22, 2024
25aa786
fix how width_multiplier is applied
lintangsutawika Feb 22, 2024
4d246a1
updated plot config
lintangsutawika Feb 22, 2024
84c5380
update files
lintangsutawika Feb 26, 2024
b2f1101
Merge branch 'main' into rework-mup
lintangsutawika Feb 26, 2024
42d4cde
Update NeoXArgs docs automatically
invalid-email-address Feb 26, 2024
4c477d5
init function, add input embedding different initialization
lintangsutawika Feb 27, 2024
64dc4c5
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika Feb 27, 2024
65c103e
change output layer to normal
lintangsutawika Feb 27, 2024
08b5d40
change from mean to std
lintangsutawika Feb 27, 2024
2ca94a8
double attention head for every hidden size doubled
lintangsutawika Feb 27, 2024
7483246
Merge branch 'main' into rework-mup
lintangsutawika Feb 27, 2024
497485c
Update NeoXArgs docs automatically
invalid-email-address Feb 27, 2024
34fb7ca
added args
lintangsutawika Feb 27, 2024
2d53f1f
simplify coordcheck
lintangsutawika Feb 27, 2024
7897610
separate sp and mup configs
lintangsutawika Feb 27, 2024
4f39209
perform coordcheck for sp and mup separately
lintangsutawika Feb 27, 2024
5f84a3f
Update NeoXArgs docs automatically
invalid-email-address Feb 27, 2024
479b854
update
lintangsutawika Feb 28, 2024
21a7e32
update how params are sorted
lintangsutawika Feb 28, 2024
bb2e0c9
remove unused comments
lintangsutawika Feb 28, 2024
bf1ce06
adjust
lintangsutawika Feb 29, 2024
50a3dba
simplify
lintangsutawika Feb 29, 2024
c4c1660
fix mup embedding multiplier
lintangsutawika Feb 29, 2024
1c35911
embeddingpipe fix init
lintangsutawika Feb 29, 2024
84be4d4
changed how manual seed is loaded
lintangsutawika Feb 29, 2024
fbb4daf
removed musgd and other changes
lintangsutawika Feb 29, 2024
fa142ff
update config
lintangsutawika Feb 29, 2024
ad2336f
fixed how params are sorted
lintangsutawika Feb 29, 2024
fe73bc3
update how seed is computed
lintangsutawika Feb 29, 2024
a3bd44c
update to follow pre-commit format
lintangsutawika Feb 29, 2024
56b6c9b
update from main
lintangsutawika Feb 29, 2024
2365fd5
update
lintangsutawika Feb 29, 2024
e8639a0
Update NeoXArgs docs automatically
invalid-email-address Feb 29, 2024
47e1438
fix lr weighting
lintangsutawika Mar 5, 2024
a064f9b
hard set to 1.0 if neox_args.use_mup is false
lintangsutawika Mar 5, 2024
b0da27a
Merge branch 'main' into rework-mup
Quentin-Anthony Apr 21, 2024
6fe55f4
Update NeoXArgs docs automatically
invalid-email-address Apr 21, 2024
8bf8bcd
add new parameters
lintangsutawika May 2, 2024
7f0b033
add parameter checks
lintangsutawika May 2, 2024
f802869
updates to argument processing for mup
lintangsutawika May 2, 2024
cc71104
add data save and descriptions being printed
lintangsutawika May 2, 2024
c8feb39
update mup
lintangsutawika May 2, 2024
b6b3a02
update seed
lintangsutawika May 2, 2024
847e892
remove print text
lintangsutawika May 2, 2024
1b0027c
fixed kv
lintangsutawika May 2, 2024
055596f
update
lintangsutawika May 2, 2024
fabb45b
update descriptions being printed
lintangsutawika May 2, 2024
5ccf693
removed unused lines
lintangsutawika May 2, 2024
9dd583b
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika May 2, 2024
6a8ad71
Merge branch 'main' into rework-mup
lintangsutawika May 2, 2024
485cad4
Update NeoXArgs docs automatically
invalid-email-address May 2, 2024
52 changes: 17 additions & 35 deletions configs/neox_arguments.md
@@ -111,7 +111,11 @@ Logging Arguments

- **git_hash**: str

<<<<<<< HEAD
Default = 02687a8
=======
Default = 31cb364
>>>>>>> e5a7ea71e96eeada636c9612036dc85e886d973d

current git hash of repository

@@ -460,6 +464,7 @@ Model Arguments
Default = 0.02

Standard deviation of the zero mean normal distribution used for weight initialization.
When using muP this is the base std



@@ -671,6 +676,7 @@ Optimizer Arguments
Default = None

Max Learning rate during training
When using muP, this is the base lr



@@ -1529,7 +1535,7 @@ Training Arguments

Default = False

Whether to use Microsoft's Mup https://github.com/microsoft/mup
Whether to use muP



@@ -1557,52 +1563,28 @@ Training Arguments



- **mup_init_scale**: float
- **mup_emb**: int

Default = 1.0

Initialization scale: All the parameters are multiplied by this value



- **mup_attn_temp**: float

Default = 1.0

Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax



- **mup_output_temp**: float

Default = 1.0

Output temperature: Reciprocal of the multiplier applied to the input to softmax that
produces the distribution over output tokens.



- **mup_embedding_mult**: float

Default = 1.0
Default = 1

Scalar by which we multiply the output of the embedding layer
Embedding output multiplier



- **mup_rp_embedding_mult**: float
- **mup_m_width**: int

Default = 1.0
Default = 1

Scalar by which we multiply vectors representing relative position
Manually set the layer width multiplier (d_model/d_model,base)



- **mup_width_scale**: int
- **mup_d_model_base**: int

Default = 2
Default = 64

What to scale width by when creating the delta model for mup
d_model,base
Proxy (base) model's layer width
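
To make these arguments concrete, here is a minimal sketch (not part of the PR; all values are hypothetical) of how the width multiplier would be derived from `hidden_size` and `mup_d_model_base`, and how the base `init_method_std` and `lr` are then adjusted, matching the init and learning-rate changes later in this diff:

```python
import math

# Hypothetical values, for illustration only.
mup_d_model_base = 64    # width of the small proxy model (d_model,base)
hidden_size = 1024       # d_model of the target model
init_method_std = 0.02   # base std (see the note above)
lr = 6.0e-4              # base lr (see the note above)

# Width multiplier, as described for mup_m_width: d_model / d_model,base
mup_m_width = hidden_size / mup_d_model_base  # -> 16.0

# Adjustments implied elsewhere in this diff:
hidden_init_std = init_method_std / math.sqrt(mup_m_width)  # init_functions.py
hidden_lr = lr / mup_m_width                                 # learning_rates.py

print(mup_m_width, hidden_init_std, hidden_lr)
```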



6 changes: 4 additions & 2 deletions megatron/learning_rates.py
@@ -37,6 +37,7 @@ def __init__(
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False,
use_mup=False,
mup_m_width=1,
):

# Class values.
@@ -51,6 +52,7 @@ def __init__(
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
self.use_mup = use_mup
self.mup_m_width = mup_m_width
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, (
"both override and " "use-checkpoint are set."
@@ -95,8 +97,8 @@ def step(self, step_num=None):
self.num_iters = step_num
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
if self.use_mup and "width_mult" in group:
group["lr"] = new_lr / group["width_mult"]
if self.use_mup and ("lr_adjust" in group) and group["lr_adjust"] is True:
group["lr"] = new_lr / self.mup_m_width
else:
group["lr"] = new_lr
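
As a rough illustration of what the `lr_adjust` flag does here (the flag name comes from this hunk; the param-group construction and values below are illustrative, not the PR's actual code), groups flagged for adjustment have their learning rate divided by the width multiplier while the rest keep the base rate:

```python
import torch

# Toy stand-in for the real param-group construction elsewhere in this PR.
model = torch.nn.Sequential(torch.nn.Embedding(100, 64), torch.nn.Linear(64, 64))
param_groups = [
    {"params": list(model[0].parameters()), "lr_adjust": False},  # embeddings keep the base LR
    {"params": list(model[1].parameters()), "lr_adjust": True},   # hidden weights get LR / m_width
]
optimizer = torch.optim.Adam(param_groups, lr=6.0e-4)

# Mirrors the scheduler step above.
use_mup, mup_m_width, new_lr = True, 16.0, 6.0e-4
for group in optimizer.param_groups:
    if use_mup and group.get("lr_adjust", False):
        group["lr"] = new_lr / mup_m_width
    else:
        group["lr"] = new_lr

print([g["lr"] for g in optimizer.param_groups])  # [0.0006, 3.75e-05]
```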

10 changes: 2 additions & 8 deletions megatron/model/gpt2_model.py
@@ -175,6 +175,7 @@ def init_specs(self):
# Embedding layer
# input will be (input_ids, position_ids, attention_mask)

# TODO Initialized weights here should not be divided by m_width
if weight_tying:
self.specs.append(
TiedLayerSpec(
@@ -268,16 +269,9 @@ def init_specs(self):

def _logits_helper(embedding, lm_output):
"""Just a wrapper to massage inputs/outputs from pipeline."""
if self.neox_args.use_mup:
# Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout.
# https://github.com/microsoft/mup/issues/6#issuecomment-1082156274
lm_output = (
lm_output
/ self.tied_modules.embed.word_embeddings.weight.infshape.width_mult()
)

logits = parallel_lm_logits(
lm_output, embedding.word_embeddings_weight, self.parallel_output
lm_output, embedding.word_embeddings_weight, self.parallel_output, self.neox_args
)
return logits
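
A quick illustrative check (hypothetical shapes) of why the old MuReadout-style division of `lm_output` can be replaced by dividing the logits inside `parallel_lm_logits`: the projection is linear, so the two orderings give the same result up to floating-point rounding.

```python
import torch

torch.manual_seed(0)
mup_m_width = 16.0
lm_output = torch.randn(4, 64)                  # [tokens, hidden]
word_embeddings_weight = torch.randn(1000, 64)  # [vocab, hidden]

# Old style: scale the hidden states before the tied-embedding projection.
old_style = (lm_output / mup_m_width) @ word_embeddings_weight.t()
# New style: project first, then divide the logits (as parallel_lm_logits now does).
new_style = (lm_output @ word_embeddings_weight.t()) / mup_m_width

print(torch.allclose(old_style, new_style, atol=1e-6))  # True
```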

117 changes: 42 additions & 75 deletions megatron/model/init_functions.py
@@ -16,41 +16,22 @@

import torch

try:
import mup
except ImportError:
pass


def init_method_normal(sigma, use_mup_outer=False, mup_init_scale=1.0):
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""

def init_(tensor, use_mup=use_mup_outer):
if use_mup:
mup.init.normal_(tensor, mean=0.0, std=sigma)
with torch.no_grad():
tensor.mul_(mup_init_scale)
return tensor
else:
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

return init_


def scaled_init_method_normal(
sigma, num_layers, use_mup_outer=False, mup_init_scale=1.0
):
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)

def init_(tensor, use_mup=use_mup_outer):
if use_mup:
mup.init.normal_(tensor, mean=0.0, std=std)
with torch.no_grad():
tensor.mul_(mup_init_scale)
return tensor
else:
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)

return init_

@@ -87,12 +68,12 @@ def _orthogonal(tensor, gain=1):
return tensor


def orthogonal_init_method(n_layers=1, use_mup=False, mup_init_scale=1.0):
def orthogonal_init_method(n_layers=1, mup_m_width=1.0):
"""Fills the input Tensor with a (semi) orthogonal matrix, as described in
Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013)
Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released)"""

if use_mup:
if mup_m_width != 1:
raise ValueError(
"Orthogonal init needs to be patched to support mup. Disable mup or use a different init method to avoid this error"
)
@@ -103,105 +84,91 @@ def init_(tensor):
return init_


def xavier_uniform_init_method(use_mup_outer=False, mup_init_scale=1.0):
def xavier_uniform_init_method(mup_m_width=1.0):
"""Fills the input Tensor with values according to the method described in Understanding the difficulty of
training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution."""

def init_(tensor, use_mup=use_mup_outer):
if use_mup:
mup.init.xavier_uniform_(tensor)
def init_(tensor, mup_m_width=mup_m_width):
init_weight = torch.nn.init.xavier_uniform_(tensor)
if mup_m_width != 1:
with torch.no_grad():
tensor.mul_(mup_init_scale)
return tensor
else:
return torch.nn.init.xavier_uniform_(tensor)
init_weight.div_(mup_m_width)
return init_weight

return init_


def xavier_normal_init_method(use_mup_outer=False, mup_init_scale=1.0):
def xavier_normal_init_method(mup_m_width=1.0):
"""Fills the input Tensor with values according to the method described in Understanding the difficulty of
training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution."""

def init_(tensor, use_mup=use_mup_outer):
if use_mup:
mup.init.xavier_normal_(tensor)
def init_(tensor, mup_m_width=mup_m_width):
init_weight = torch.nn.init.xavier_normal_(tensor)
if mup_m_width != 1:
with torch.no_grad():
tensor.mul_(mup_init_scale)
return tensor
else:
return torch.nn.init.xavier_normal_(tensor)
init_weight.div_(mup_m_width)
return init_weight

return init_


def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0):
def small_init_init_method(dim, mup_m_width=1.0):
"""Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
std = math.sqrt(2 / (5 * dim))

def init_(tensor, use_mup=use_mup_outer):
if use_mup:
mup.init.normal_(tensor, mean=0.0, std=std)
def init_(tensor, mup_m_width=mup_m_width):
init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
if mup_m_width != 1:
with torch.no_grad():
tensor.mul_(mup_init_scale)
return tensor
else:
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
init_weight.div_(mup_m_width)
return init_weight

return init_


def wang_init_method(n_layers, dim, use_mup_outer=False, mup_init_scale=1.0):
def wang_init_method(n_layers, dim, mup_m_width=1.0):
std = 2 / n_layers / math.sqrt(dim)

def init_(tensor, use_mup=use_mup_outer):
if use_mup:
mup.init.normal_(tensor, mean=0.0, std=std)
def init_(tensor, mup_m_width=mup_m_width):
init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
if mup_m_width != 1:
with torch.no_grad():
tensor.mul_(mup_init_scale)
return tensor
else:
return torch.nn.init.normal_(tensor, mean=0.0, std=std)

init_weight.div_(mup_m_width)
return init_weight

return init_


def get_init_methods(args):

if args.use_mup:
try:
import mup
except ModuleNotFoundError:
print("Please install mup https://github.com/microsoft/mup")
raise Exception

def _get(name):
if name == "normal":
return init_method_normal(
args.init_method_std, args.use_mup, args.mup_init_scale
sigma=args.init_method_std/math.sqrt(args.mup_m_width)
Review comment: if you are going to define the muP init_std adjustment here, then you need to remove all the init_weight.div_ calls so you don't double adjust.
Reply (lintangsutawika, author): I'm split on what the best interface would be. For init_method_normal and scaled_init_method_normal I opted to not scale in the function, but for the other init functions I did.

)
elif name == "scaled_normal":
return scaled_init_method_normal(
args.init_method_std, args.num_layers, args.use_mup, args.mup_init_scale
sigma=args.init_method_std/math.sqrt(args.mup_m_width),
num_layers=args.num_layers
)
elif name == "orthogonal":
return orthogonal_init_method(args.use_mup, args.mup_init_scale)
return orthogonal_init_method(args.mup_m_width)
elif name == "scaled_orthogonal":
return orthogonal_init_method(
args.num_layers, args.use_mup, args.mup_init_scale
args.num_layers, args.mup_m_width
)
elif name == "xavier_uniform":
return xavier_uniform_init_method(args.use_mup, args.mup_init_scale)
return xavier_uniform_init_method(args.mup_m_width)
elif name == "xavier_normal":
return xavier_normal_init_method(args.use_mup, args.mup_init_scale)
return xavier_normal_init_method(args.mup_m_width)
elif name == "wang_init":
return wang_init_method(
args.num_layers, args.hidden_size, args.use_mup, args.mup_init_scale
args.num_layers, args.hidden_size, args.mup_m_width
)
elif name == "small_init":
return small_init_init_method(
args.hidden_size, args.use_mup, args.mup_init_scale
args.hidden_size, args.mup_m_width
)
else:
raise NotImplementedError(f"Unknown init method {name}")
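
To make the interface question above concrete, here is a small sketch (hypothetical `sigma` and width multiplier; a plain normal init is used for both cases to isolate the difference) contrasting the two adjustment styles this file currently mixes: scaling the standard deviation before sampling versus dividing the sampled weights afterwards. The two differ by a factor of `sqrt(mup_m_width)`, which is the double-adjustment concern raised in the review.

```python
import math
import torch

sigma, mup_m_width = 0.02, 16.0
w_a = torch.empty(512, 512)
w_b = torch.empty(512, 512)

# Style A (init_method_normal / scaled_init_method_normal above):
# fold the adjustment into the std before sampling.
torch.nn.init.normal_(w_a, mean=0.0, std=sigma / math.sqrt(mup_m_width))

# Style B (the xavier/small_init/wang branches above):
# sample at the base std, then divide the weights in place.
torch.nn.init.normal_(w_b, mean=0.0, std=sigma)
with torch.no_grad():
    w_b.div_(mup_m_width)

print(w_a.std().item(), w_b.std().item())  # ~0.005 vs ~0.00125
```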
15 changes: 9 additions & 6 deletions megatron/model/transformer.py
@@ -306,13 +306,13 @@ def __init__(
)

coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = max(1, self.layer_number)
self.norm_factor *= coeff

if neox_args.use_mup:
self.norm_factor = self.hidden_size_per_attention_head
else:
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = max(1, self.layer_number)
self.norm_factor *= coeff

self.rpe = rpe
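
A compact sketch (illustrative shapes only) of the attention-score scaling this hunk switches between: the standard parameterization divides the scores by `sqrt(d_head)`, while muP divides by `d_head`, matching the `norm_factor` branch above.

```python
import math
import torch

d_head, seq_len = 64, 8
q = torch.randn(seq_len, d_head)
k = torch.randn(seq_len, d_head)

use_mup = True
norm_factor = d_head if use_mup else math.sqrt(d_head)  # mirrors the branch above

scores = (q @ k.t()) / norm_factor
probs = torch.softmax(scores, dim=-1)
print(probs.shape)  # torch.Size([8, 8])
```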

@@ -960,7 +960,7 @@ def forward(self, args):
return self.norm(args)


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None, args=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_model_parallel_region(input_)
@@ -971,6 +971,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)

if args is not None and args.use_mup:
logits_parallel /= args.mup_m_width
Review comment: you may also want to multiply by some tunable scalar value here (like we did in BLTM).
Reply (lintangsutawika, author): I thought that mup_m_width was the tunable scalar for the logits here?
Reply (lintangsutawika, author): Is this the part where Y_logits = W_unemb X / m_width?


# Gather if needed.
if parallel_output:
return logits_parallel
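
A minimal sketch of the readout scaling in this hunk: the logits are divided by the width multiplier. The extra tunable scalar suggested in the review is shown as a hypothetical `output_mult`; it is not defined anywhere in this diff.

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(4, 64)                     # [tokens, hidden]
word_embeddings_weight = torch.randn(1000, 64)  # [vocab, hidden]
use_mup, mup_m_width, output_mult = True, 16.0, 1.0  # output_mult is hypothetical

logits = F.linear(hidden, word_embeddings_weight)
if use_mup:
    logits = logits * output_mult / mup_m_width  # matches logits_parallel /= mup_m_width above
print(logits.shape)  # torch.Size([4, 1000])
```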