[muP] Rework #1087

Draft

Wants to merge 107 commits into base: main from rework-mup
Changes from 1 commit

Commits (107)
0d921f7
changed ordering for setting up norm_factor
lintangsutawika Dec 1, 2023
abee54d
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
a08c3ef
updated muP args to the minimum required
lintangsutawika Dec 1, 2023
c35e830
calculate m_width
lintangsutawika Dec 1, 2023
2807e52
Merge branch 'main' of https://github.com/EleutherAI/gpt-neox into re…
lintangsutawika Dec 1, 2023
2d127df
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika Dec 1, 2023
81fdc4d
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
7d6b246
changed ordering for setting up norm_factor
lintangsutawika Dec 1, 2023
a0d1929
updated muP args to the minimum required
lintangsutawika Dec 1, 2023
d63b3b8
calculate m_width
lintangsutawika Dec 1, 2023
9be82fe
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
66214d9
removed redundant line
lintangsutawika Dec 1, 2023
17b7183
removed redundant lines
lintangsutawika Dec 1, 2023
a6bad07
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
63984bd
removed redundant lines
lintangsutawika Dec 1, 2023
02687a8
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika Dec 1, 2023
11114e2
Update NeoXArgs docs automatically
invalid-email-address Dec 1, 2023
05c4de3
modify init with mup
lintangsutawika Dec 1, 2023
71a91e4
divide logits by the m_width
lintangsutawika Dec 1, 2023
99c8ce0
moved position of mup parameters being processed
lintangsutawika Dec 1, 2023
b253ab6
add note
lintangsutawika Dec 1, 2023
1919499
made param groups to hold flag for mup scaling
lintangsutawika Dec 6, 2023
17678e0
lr scale
lintangsutawika Dec 6, 2023
2bd5ae6
update config
lintangsutawika Dec 6, 2023
6642291
adjust process of mup variables
lintangsutawika Dec 6, 2023
8be6c66
remove calling save_base_shapes
lintangsutawika Dec 18, 2023
c9fb18b
lr adjustments are done in train_step to address lr being reset due to…
lintangsutawika Dec 18, 2023
795371c
lr scaling for mup is moved here instead
lintangsutawika Dec 18, 2023
087beee
removed mup usage for coord check
lintangsutawika Jan 3, 2024
16d04b1
merged with main
lintangsutawika Jan 3, 2024
e7b7bf6
latest update on coord check implementation
lintangsutawika Jan 24, 2024
8dea9ce
fix merge conflict
lintangsutawika Feb 2, 2024
3664eba
changed `mup_m_width` to `mup_width_multiplier`
lintangsutawika Feb 2, 2024
6a46247
fixed notations
lintangsutawika Feb 2, 2024
7439f9a
correct scale
lintangsutawika Feb 2, 2024
5b2d31c
m_emb * embed(X)
lintangsutawika Feb 2, 2024
98caa82
removed mup rescale in the layers
lintangsutawika Feb 2, 2024
5c99637
removed mup rescale in the layers
lintangsutawika Feb 2, 2024
a636f06
adjust mup_m_emb to mup_embedding_multiplier
lintangsutawika Feb 2, 2024
39190c5
add multiplier mup_output_multiplier
lintangsutawika Feb 20, 2024
2489cc0
reorder model loading
lintangsutawika Feb 20, 2024
23b8776
removed comments
lintangsutawika Feb 20, 2024
10e935e
removed comments
lintangsutawika Feb 20, 2024
a0aca99
implement full process
lintangsutawika Feb 20, 2024
9472b35
set neox_args.iteration to 0 for coord_check mode
lintangsutawika Feb 21, 2024
5c5f2df
move mup_width_multiplier init
lintangsutawika Feb 21, 2024
7eca3e7
mup_coord_check returns 2 df
lintangsutawika Feb 21, 2024
c9a3a65
can run
lintangsutawika Feb 21, 2024
a7877d4
remove comments
lintangsutawika Feb 22, 2024
bd9d399
add hooks
lintangsutawika Feb 22, 2024
fe180d3
remove comments
lintangsutawika Feb 22, 2024
b240c19
uncomment activation data
lintangsutawika Feb 22, 2024
93b4241
plot coords
lintangsutawika Feb 22, 2024
d4899fc
removed variables, add way to plot only from rank 0
lintangsutawika Feb 22, 2024
f589e29
changed key name in dict
lintangsutawika Feb 22, 2024
8261e0d
remove print
lintangsutawika Feb 22, 2024
25aa786
fix how width_multiplier is applied
lintangsutawika Feb 22, 2024
4d246a1
updated plot config
lintangsutawika Feb 22, 2024
84c5380
update files
lintangsutawika Feb 26, 2024
b2f1101
Merge branch 'main' into rework-mup
lintangsutawika Feb 26, 2024
42d4cde
Update NeoXArgs docs automatically
invalid-email-address Feb 26, 2024
4c477d5
init function, add input embedding different initialization
lintangsutawika Feb 27, 2024
64dc4c5
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika Feb 27, 2024
65c103e
change output layer to normal
lintangsutawika Feb 27, 2024
08b5d40
change from mean to std
lintangsutawika Feb 27, 2024
2ca94a8
double attention head for every hidden size doubled
lintangsutawika Feb 27, 2024
7483246
Merge branch 'main' into rework-mup
lintangsutawika Feb 27, 2024
497485c
Update NeoXArgs docs automatically
invalid-email-address Feb 27, 2024
34fb7ca
added args
lintangsutawika Feb 27, 2024
2d53f1f
simplify coordcheck
lintangsutawika Feb 27, 2024
7897610
separate sp and mup configs
lintangsutawika Feb 27, 2024
4f39209
perform coordcheck for sp and mup separately
lintangsutawika Feb 27, 2024
5f84a3f
Update NeoXArgs docs automatically
invalid-email-address Feb 27, 2024
479b854
update
lintangsutawika Feb 28, 2024
21a7e32
update how params are sorted
lintangsutawika Feb 28, 2024
bb2e0c9
remove unused comments
lintangsutawika Feb 28, 2024
bf1ce06
adjust
lintangsutawika Feb 29, 2024
50a3dba
simplify
lintangsutawika Feb 29, 2024
c4c1660
fix mup embedding multiplier
lintangsutawika Feb 29, 2024
1c35911
embeddingpipe fix init
lintangsutawika Feb 29, 2024
84be4d4
changed how manual seed is loaded
lintangsutawika Feb 29, 2024
fbb4daf
removed musgd and other changes
lintangsutawika Feb 29, 2024
fa142ff
update config
lintangsutawika Feb 29, 2024
ad2336f
fixed how params are sorted
lintangsutawika Feb 29, 2024
fe73bc3
update how seed is computed
lintangsutawika Feb 29, 2024
a3bd44c
update to follow pre-commit format
lintangsutawika Feb 29, 2024
56b6c9b
update from main
lintangsutawika Feb 29, 2024
2365fd5
update
lintangsutawika Feb 29, 2024
e8639a0
Update NeoXArgs docs automatically
invalid-email-address Feb 29, 2024
47e1438
fix lr weighting
lintangsutawika Mar 5, 2024
a064f9b
hard set to 1.0 if neox_args.use_mup is false
lintangsutawika Mar 5, 2024
b0da27a
Merge branch 'main' into rework-mup
Quentin-Anthony Apr 21, 2024
6fe55f4
Update NeoXArgs docs automatically
invalid-email-address Apr 21, 2024
8bf8bcd
add new parameters
lintangsutawika May 2, 2024
7f0b033
add parameter checks
lintangsutawika May 2, 2024
f802869
updates to argument processing for mup
lintangsutawika May 2, 2024
cc71104
add data save and descriptions being printed
lintangsutawika May 2, 2024
c8feb39
update mup
lintangsutawika May 2, 2024
b6b3a02
update seed
lintangsutawika May 2, 2024
847e892
remove print text
lintangsutawika May 2, 2024
1b0027c
fixed kv
lintangsutawika May 2, 2024
055596f
update
lintangsutawika May 2, 2024
fabb45b
update descriptions being printed
lintangsutawika May 2, 2024
5ccf693
removed unused lines
lintangsutawika May 2, 2024
9dd583b
Merge branch 'rework-mup' of https://github.com/EleutherAI/gpt-neox i…
lintangsutawika May 2, 2024
6a8ad71
Merge branch 'main' into rework-mup
lintangsutawika May 2, 2024
485cad4
Update NeoXArgs docs automatically
invalid-email-address May 2, 2024
changed mup_m_width to mup_width_multiplier
lintangsutawika committed Feb 2, 2024
commit 3664ebab5a0eb1614a1883c313584a1d623f5256
4 changes: 2 additions & 2 deletions configs/neox_arguments.md

@@ -1567,7 +1567,7 @@ Training Arguments

-- **mup_m_width**: int
+- **mup_width_multiplier**: int

 Default = 1

@@ -1577,7 +1577,7 @@ Training Arguments

 - **mup_d_model_base**: int

-Default = 64
+Default = 256

 d_model,base
 Proxy (base) model's layer width
6 changes: 3 additions & 3 deletions megatron/learning_rates.py

@@ -37,7 +37,7 @@ def __init__(
         use_checkpoint_lr_scheduler=True,
         override_lr_scheduler=False,
         use_mup=False,
-        mup_m_width=1,
+        mup_width_multiplier=1,
     ):

         # Class values.
@@ -52,7 +52,7 @@ def __init__(
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         self.use_mup = use_mup
-        self.mup_m_width = mup_m_width
+        self.mup_width_multiplier = mup_width_multiplier
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, (
                 "both override and " "use-checkpoint are set."
@@ -98,7 +98,7 @@ def step(self, step_num=None):
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             if self.use_mup and ("lr_adjust" in group) and group["lr_adjust"] is True:
-                group["lr"] = new_lr / self.mup_m_width
+                group["lr"] = new_lr / self.mup_width_multiplier
             else:
                 group["lr"] = new_lr

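The `lr_adjust` flag that `step()` checks above has to be attached to the optimizer parameter groups elsewhere (see the commit "made param groups to hold flag for mup scaling" in the log above). Below is a minimal sketch of such tagging; the helper `build_param_groups` and the `is_matrix_like` heuristic are illustrative assumptions, not this PR's actual code. The idea is that only width-scaled ("matrix-like") weights get their learning rate divided by `mup_width_multiplier`, while embeddings, biases, and norm parameters keep the base rate.

```python
# Hedged sketch, not this PR's implementation: tag optimizer param groups with
# an "lr_adjust" flag so the scheduler logic above can divide the learning
# rate by mup_width_multiplier for width-scaled weights only.
def build_param_groups(model, base_lr):
    def is_matrix_like(name, param):
        # Hidden weights whose fan-in and fan-out both grow with model width
        # get the muP 1/width learning-rate scaling; embeddings, biases, and
        # norm parameters stay at the base learning rate.
        return param.dim() >= 2 and "embed" not in name

    adjusted, regular = [], []
    for name, param in model.named_parameters():
        (adjusted if is_matrix_like(name, param) else regular).append(param)

    return [
        {"params": adjusted, "lr": base_lr, "lr_adjust": True},
        {"params": regular, "lr": base_lr, "lr_adjust": False},
    ]


# Example usage (optimizer choice assumed):
# optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=6e-4))
```
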
52 changes: 26 additions & 26 deletions megatron/model/init_functions.py

@@ -68,12 +68,12 @@ def _orthogonal(tensor, gain=1):
     return tensor


-def orthogonal_init_method(n_layers=1, mup_m_width=1.0):
+def orthogonal_init_method(n_layers=1, mup_width_multiplier=1.0):
     """Fills the input Tensor with a (semi) orthogonal matrix, as described in
     Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013)
     Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released)"""

-    if mup_m_width != 1:
+    if mup_width_multiplier != 1:
         raise ValueError(
             "Orthogonal init needs to be patched to support mup. Disable mup or use a different init method to avoid this error"
         )
@@ -84,57 +84,57 @@ def init_(tensor):
     return init_


-def xavier_uniform_init_method(mup_m_width=1.0):
+def xavier_uniform_init_method(mup_width_multiplier=1.0):
     """Fills the input Tensor with values according to the method described in Understanding the difficulty of
     training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution."""

-    def init_(tensor, mup_m_width=mup_m_width):
+    def init_(tensor, mup_width_multiplier=mup_width_multiplier):
         init_weight = torch.nn.init.xavier_uniform_(tensor)
-        if mup_m_width != 1:
+        if mup_width_multiplier != 1:
             with torch.no_grad():
-                init_weight.div_(mup_m_width)
+                init_weight.div_(mup_width_multiplier)
         return init_weight

     return init_


-def xavier_normal_init_method(mup_m_width=1.0):
+def xavier_normal_init_method(mup_width_multiplier=1.0):
     """Fills the input Tensor with values according to the method described in Understanding the difficulty of
     training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution."""

-    def init_(tensor, mup_m_width=mup_m_width):
+    def init_(tensor, mup_width_multiplier=mup_width_multiplier):
         init_weight = torch.nn.init.xavier_normal_(tensor)
-        if mup_m_width != 1:
+        if mup_width_multiplier != 1:
             with torch.no_grad():
-                init_weight.div_(mup_m_width)
+                init_weight.div_(mup_width_multiplier)
         return init_weight

     return init_


-def small_init_init_method(dim, mup_m_width=1.0):
+def small_init_init_method(dim, mup_width_multiplier=1.0):
     """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
     the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
     std = math.sqrt(2 / (5 * dim))

-    def init_(tensor, mup_m_width=mup_m_width):
+    def init_(tensor, mup_width_multiplier=mup_width_multiplier):
         init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
-        if mup_m_width != 1:
+        if mup_width_multiplier != 1:
             with torch.no_grad():
-                init_weight.div_(mup_m_width)
+                init_weight.div_(mup_width_multiplier)
         return init_weight

     return init_


-def wang_init_method(n_layers, dim, mup_m_width=1.0):
+def wang_init_method(n_layers, dim, mup_width_multiplier=1.0):
     std = 2 / n_layers / math.sqrt(dim)

-    def init_(tensor, mup_m_width=mup_m_width):
+    def init_(tensor, mup_width_multiplier=mup_width_multiplier):
         init_weight = torch.nn.init.normal_(tensor, mean=0.0, std=std)
-        if mup_m_width != 1:
+        if mup_width_multiplier != 1:
             with torch.no_grad():
-                init_weight.div_(mup_m_width)
+                init_weight.div_(mup_width_multiplier)
         return init_weight

     return init_
@@ -145,30 +145,30 @@ def get_init_methods(args):
     def _get(name):
         if name == "normal":
             return init_method_normal(
-                sigma=args.init_method_std/math.sqrt(args.mup_m_width)
+                sigma=args.init_method_std/math.sqrt(args.mup_width_multiplier)

[Review comment on the line above] During our call, we noted that this leads to a bug: since the width multiplier is applied to all layers, this doesn't allow the embedding layer to be initialized differently from the transformer backbone layers (precisely, muP prescribes that layers whose input and output dimensions both scale with width need a sqrt(width) multiplying factor).

[Review comment] Would recommend refactoring this code: remove the muP width multiplier completely from the initialization methods code, and only take the initializer parameters in here (e.g., standard deviation). Then, when the initializers are used from various layers, adjust the initializer based on that particular layer's muP width adjustment requirements. (A minimal sketch of such a refactor follows this file's diff.)

             )
         elif name == "scaled_normal":
             return scaled_init_method_normal(
-                sigma=args.init_method_std/math.sqrt(args.mup_m_width),
+                sigma=args.init_method_std/math.sqrt(args.mup_width_multiplier),
                 num_layers=args.num_layers
             )
         elif name == "orthogonal":
-            return orthogonal_init_method(args.mup_m_width)
+            return orthogonal_init_method(args.mup_width_multiplier)
         elif name == "scaled_orthogonal":
             return orthogonal_init_method(
-                args.num_layers, args.mup_m_width
+                args.num_layers, args.mup_width_multiplier
             )
         elif name == "xavier_uniform":
-            return xavier_uniform_init_method(args.mup_m_width)
+            return xavier_uniform_init_method(args.mup_width_multiplier)
         elif name == "xavier_normal":
-            return xavier_normal_init_method(args.mup_m_width)
+            return xavier_normal_init_method(args.mup_width_multiplier)
         elif name == "wang_init":
             return wang_init_method(
-                args.num_layers, args.hidden_size, args.mup_m_width
+                args.num_layers, args.hidden_size, args.mup_width_multiplier
             )
         elif name == "small_init":
             return small_init_init_method(
-                args.hidden_size, args.mup_m_width
+                args.hidden_size, args.mup_width_multiplier
             )
         else:
             raise NotImplementedError(f"Unknown init method {name}")

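To make the refactor suggested in the review comment concrete, here is a minimal sketch under stated assumptions; it is not this PR's code, and the names `make_normal_init`, `mup_adjusted_init`, and `scales_with_width` are invented for illustration. The initializer factory only knows the distribution parameters (e.g., the standard deviation), and each call site applies its own muP width correction, so the embedding layer can be initialized differently from the hidden matrix-like layers, as the comment prescribes.

```python
# Hedged sketch of the suggested refactor, not this PR's implementation.
import math
import torch


def make_normal_init(std):
    """Plain initializer: fills a tensor from N(0, std**2); no muP logic here."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_


def mup_adjusted_init(base_init, width_multiplier, scales_with_width):
    """Wrap a plain initializer with a layer-specific muP correction.

    Hidden weights whose input and output dimensions both scale with width get
    their init divided by sqrt(width_multiplier); embedding/readout layers pass
    scales_with_width=False and keep the base init unchanged.
    """
    def init_(tensor):
        init_weight = base_init(tensor)
        if scales_with_width and width_multiplier != 1:
            with torch.no_grad():
                init_weight.div_(math.sqrt(width_multiplier))
        return init_weight
    return init_


# Example usage with assumed values:
# base = make_normal_init(std=0.02)
# hidden_init = mup_adjusted_init(base, width_multiplier=4.0, scales_with_width=True)
# embed_init = mup_adjusted_init(base, width_multiplier=4.0, scales_with_width=False)
```

The design point is that the question "does this layer's initialization scale with width?" is answered by the layer that owns the weight, not inside a shared initializer factory.
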
2 changes: 1 addition & 1 deletion megatron/model/transformer.py

@@ -972,7 +972,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non
     logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)

     if args is not None and args.use_mup:
-        logits_parallel /= args.mup_m_width
+        logits_parallel /= args.mup_width_multiplier

     # Gather if needed.
     if parallel_output:
4 changes: 2 additions & 2 deletions megatron/mpu/layers.py

@@ -429,7 +429,7 @@ def __init__(
         self.stride = stride
         self.mup_rescale_parameters = mup_rescale_parameters
         self.use_mup = neox_args.use_mup
-        self.m_width = neox_args.mup_m_width
+        self.m_width = neox_args.mup_width_multiplier

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -627,7 +627,7 @@ def __init__(
         self.keep_master_weight_for_test = keep_master_weight_for_test
         self.mup_rescale_parameters = mup_rescale_parameters
         self.use_mup = neox_args.use_mup
-        self.m_width = neox_args.mup_m_width
+        self.m_width = neox_args.mup_width_multiplier

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
2 changes: 1 addition & 1 deletion megatron/neox_arguments/neox_args.py

@@ -1046,7 +1046,7 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Embedding output multiplier
     """

-    mup_m_width: float = None
+    mup_width_multiplier: float = None
     """
     Manually set the layer width multiplier (d_model/d_model,base)
     """
8 changes: 4 additions & 4 deletions megatron/training.py

@@ -422,9 +422,9 @@ def get_model(neox_args, use_cache=False):
     # neox_args.use_mup = False
     if neox_args.use_mup:

-        if neox_args.mup_m_width == 1:
-            neox_args.mup_m_width = neox_args.hidden_size / neox_args.mup_d_model_base
-        print_rank_0(f"mup_m_width set to {neox_args.mup_m_width}")
+        if neox_args.mup_width_multiplier == 1:
+            neox_args.mup_width_multiplier = neox_args.hidden_size / neox_args.mup_d_model_base
+        print_rank_0(f"mup_width_multiplier set to {neox_args.mup_width_multiplier}")

         # base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"

@@ -640,7 +640,7 @@ def get_learning_rate_scheduler(optimizer, neox_args):
         use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=neox_args.override_lr_scheduler,
         use_mup=neox_args.use_mup,
-        mup_m_width=neox_args.mup_m_width,
+        mup_width_multiplier=neox_args.mup_width_multiplier,
     )

     return lr_scheduler

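For a concrete sense of the numbers (the widths below are assumptions for illustration, not values from this PR's configs), the logic in `get_model` above derives the multiplier from the ratio of model width to proxy width:

```python
# Illustrative arithmetic only; widths are assumed, not taken from this PR.
hidden_size = 1024               # width of the scaled-up target model
mup_d_model_base = 256           # width of the proxy (base) model
mup_width_multiplier = hidden_size / mup_d_model_base  # -> 4.0
# Param groups flagged lr_adjust then train at lr / 4.0, and
# parallel_lm_logits divides the output logits by 4.0.
```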