Mup Support (#704)
* WIP: Add support for Maximal Update Parametrization and Hyperparameter Transfer (mup)

* Update to use MuAdam and MuSGD, fix minor errors

* Fix more errors with arguments

* Fix error caused by not calling to_sequential on delta model

* Update NeoXArgs docs automatically

* Address PR feedback

* Fix minor error

* Update NeoXArgs docs automatically

* Revert small.yml config

* Update NeoXArgs docs automatically

* Reinitialize weights using mup's replacements after set_base_shapes is called

* Update NeoXArgs docs automatically

* Implement rescale parameters on the output layer, adjust learning rate based on width

* Update NeoXArgs docs automatically

* Remove debug prints

* Update NeoXArgs docs automatically

* Add preliminary support for coord check (WIP: not yet functional in this commit)

* Update NeoXArgs docs automatically

* Add untracked file from last commit

* Update NeoXArgs docs automatically

* Update for coord check plots

* Update NeoXArgs docs automatically

* Add all but one (and a half) of the new hyperparameters from the zero-shot hp transfer paper

* Update NeoXArgs docs automatically

* Add last mup HP

* Add mup readme file

* Update NeoXArgs docs automatically

* Revert changes to configs/small.yml

* Update NeoXArgs docs automatically

* Update README-MUP.md

* Update NeoXArgs docs automatically

* Clean up code for PR

* Update NeoXArgs docs automatically

* Make mup import optional

* Update NeoXArgs docs automatically

* Revert "Update NeoXArgs docs automatically"

This reverts commit a7b97fd.

* Update NeoXArgs docs automatically

* Revert "Update NeoXArgs docs automatically"

This reverts commit 8161a56.

* Update NeoXArgs docs automatically

* Add neox arg for mup delta model width scale

* Update NeoXArgs docs automatically

Co-authored-by: Nick Sarkauskas <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Stella Biderman <[email protected]>
Co-authored-by: Quentin-Anthony <[email protected]>
5 people committed Dec 10, 2022
1 parent 38f4ede commit 0535bfb
Showing 11 changed files with 785 additions and 57 deletions.
49 changes: 49 additions & 0 deletions README-MUP.md
@@ -0,0 +1,49 @@
# How to use Mup (https://github.com/microsoft/mup)

## Add mup neox args to your config

```
# mup
"use-mup": true,
"save-base-shapes": false, # this only needs to be enabled once in order to generate the base-shapes-file on each rank
"base-shapes-file": "base-shapes", # load base shapes from this file
"coord-check": false, # generate coord check plots to verify mup's implementation in neox
# mup hp search
"mup-init-scale": 1.0,
"mup-attn-temp": 1.0,
"mup-output-temp": 1.0,
"mup-embedding-mult": 1.0,
"mup-rp-embedding-mult": 1.0,
```

## Generate base shapes

1. Set `use-mup` to true
2. Set `save-base-shapes` to true
3. Run once. gpt-neox will instantiate a base model and a delta model, save one file per rank named `<base-shapes-file>.<rank>`, and then exit immediately (see the sketch below).
4. Set `save-base-shapes` back to false
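
For reference, a minimal sketch of what this step amounts to in terms of the mup API, using toy `torch.nn` models rather than gpt-neox's pipeline modules (the widths and file name below are illustrative assumptions, not gpt-neox defaults):

```python
import torch.nn as nn
from mup import make_base_shapes, set_base_shapes

def build_model(hidden_size):
    # Stand-in for the gpt-neox model at a given width (illustrative only).
    return nn.Sequential(
        nn.Linear(512, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 512)
    )

base_model = build_model(hidden_size=256)   # base width
delta_model = build_model(hidden_size=512)  # base width scaled up (cf. mup-width-scale)

# Compare the two models to record which dimensions scale with width,
# and save the result to a base-shapes file (gpt-neox writes one per rank).
make_base_shapes(base_model, delta_model, savefile="base-shapes.0")

# Later runs load the saved shapes and attach them to the target model.
target_model = build_model(hidden_size=4096)
set_base_shapes(target_model, "base-shapes.0")
```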

## Generate coord check plots (optional)

1. Keep `use-mup` set to true
2. Set `coord-check` to true
3. Run once. gpt-neox will output jpg images similar to https://github.com/microsoft/mutransformers/blob/main/README.md#coord-check and then exit immediately (see the sketch below).
4. Set `coord-check` back to false
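
For context, a rough sketch of the kind of coordinate check mup performs, using mup's generic `coord_check` helpers on toy models (the model builder, widths, dataloader, and learning rate here are illustrative assumptions; gpt-neox wires this up internally):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from mup import set_base_shapes
from mup.coord_check import get_coord_data, plot_coord_data

def make_lazy_model(width):
    # Return a thunk so the coord check can rebuild the model for each trial.
    def f():
        model = nn.Sequential(nn.Linear(128, width), nn.ReLU(), nn.Linear(width, 10))
        base = nn.Sequential(nn.Linear(128, 8), nn.ReLU(), nn.Linear(8, 10))
        delta = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 10))
        return set_base_shapes(model, base, delta=delta)
    return f

models = {w: make_lazy_model(w) for w in (64, 128, 256, 512)}
data = TensorDataset(torch.randn(256, 128), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=32)

# Record per-layer activation scales over a few training steps at each width...
df = get_coord_data(models, loader, optimizer="sgd", lr=0.1)
# ...and plot them; under mup the curves should stay roughly flat as width grows.
plot_coord_data(df, save_to="coord_check.jpg")
```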

## Tune mup hyperparameters and LR

The values under `mup hp search` correspond to Appendix F.4 of https://arxiv.org/pdf/2203.03466.pdf. These hyperparameters, together with the learning rate, are tuned with a random search using the scaled-up config (tested with 6-7B.yml) but with hidden-size set to the value from the scaled-down config (small.yml). One possible shape for that search is sketched below.
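
A minimal sketch of such a random search (the `run_neox_with_overrides` helper and the sampling ranges are illustrative assumptions, not part of gpt-neox):

```python
import math
import random

# Hypothetical helper that launches a gpt-neox run with the given config
# overrides and returns the validation loss; not part of the repo.
def run_neox_with_overrides(overrides: dict) -> float:
    raise NotImplementedError

def sample_config():
    # Log-uniform LR plus the mup HPs from Appendix F.4; ranges are illustrative.
    return {
        "lr": 10 ** random.uniform(-4, -2),
        "mup-init-scale": 2 ** random.uniform(-2, 2),
        "mup-attn-temp": 2 ** random.uniform(-2, 2),
        "mup-output-temp": 2 ** random.uniform(-2, 2),
        "mup-embedding-mult": 2 ** random.uniform(-2, 2),
        "mup-rp-embedding-mult": 2 ** random.uniform(-2, 2),
    }

best = (math.inf, None)
for _ in range(64):  # number of random trials
    cfg = sample_config()
    loss = run_neox_with_overrides(cfg)
    best = min(best, (loss, cfg), key=lambda t: t[0])

print("best loss / config:", best)
```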

## Transfer

With the best LR and the best mup HPs set, revert hidden-size in the scaled-up config to its original value and run again.
88 changes: 85 additions & 3 deletions configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 50acbdd
Default = 166c5b6

current git hash of repository

@@ -575,11 +575,12 @@ Optimizer Arguments



- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd']
- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']

Default = adam

Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd]
Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
NOTE: sgd will use MuSGD from Mup. Mup must be enabled for this optimizer.



@@ -1414,6 +1415,87 @@ Training Arguments



- **use_mup**: bool

Default = False

    Whether to use Microsoft's Mup https://github.com/microsoft/mup



- **coord_check**: bool

Default = False

Whether to generate a "coord check" plot to verify mup's implementation in neox



- **save_base_shapes**: bool

Default = False

Whether to save base shapes for mup. This will save the shapes to the path specified in base-shapes-file.



- **base_shapes_file**: str

Default = None

Path to the base shapes to save to/load from



- **mup_init_scale**: float

Default = 1.0

Initialization scale: All the parameters are multiplied by this value



- **mup_attn_temp**: float

Default = 1.0

Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax



- **mup_output_temp**: float

Default = 1.0

Output temperature: Reciprocal of the multiplier applied to the input to softmax that
produces the distribution over output tokens.



- **mup_embedding_mult**: float

Default = 1.0

Scalar by which we multiply the output of the embedding layer



- **mup_rp_embedding_mult**: float

Default = 1.0

Scalar by which we multiply vectors representing relative position



- **mup_width_scale**: int

Default = 2

    Factor by which to scale model width when creating the delta model for mup



## NeoXArgsDeepspeedConfig

Args for deepspeed config
7 changes: 6 additions & 1 deletion megatron/learning_rates.py
@@ -36,6 +36,7 @@ def __init__(
        min_lr=0.0,
        use_checkpoint_lr_scheduler=True,
        override_lr_scheduler=False,
        use_mup=False,
    ):

        # Class values.
@@ -49,6 +50,7 @@ def __init__(
        self.decay_style = decay_style
        self.override_lr_scheduler = override_lr_scheduler
        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
        self.use_mup = use_mup
        if self.override_lr_scheduler:
            assert not self.use_checkpoint_lr_scheduler, (
                "both override and " "use-checkpoint are set."
@@ -90,7 +92,10 @@ def step(self, step_num=None):
        self.num_iters = step_num
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
            if self.use_mup and "width_mult" in group:
                group["lr"] = new_lr / group["width_mult"]
            else:
                group['lr'] = new_lr

    def state_dict(self):
        state_dict = {
6 changes: 6 additions & 0 deletions megatron/model/gpt2_model.py
@@ -264,6 +264,11 @@ def init_specs(self):

        def _logits_helper(embedding, lm_output):
            """Just a wrapper to massage inputs/outputs from pipeline."""
            if self.neox_args.use_mup:
                # Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout.
                # https://github.com/microsoft/mup/issues/6#issuecomment-1082156274
                lm_output = lm_output / self.tied_modules.embed.word_embeddings.weight.infshape.width_mult()

            logits = parallel_lm_logits(
                lm_output, embedding.word_embeddings_weight, self.parallel_output
            )
@@ -292,6 +297,7 @@ def _logits_helper(embedding, lm_output):
                    neox_args=self.neox_args,
                    init_method=self.init_method,
                    parallel_output=self.parallel_output,
                    is_last_layer=True,
                )
            )

107 changes: 80 additions & 27 deletions megatron/model/init_functions.py
@@ -16,28 +16,45 @@

import torch

try:
    import mup
except ImportError:
    pass

def init_method_normal(sigma):
def init_method_normal(sigma, use_mup_outer=False, mup_init_scale=1.0):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.normal_(tensor, mean=0.0, std=sigma)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        else:
            return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
def scaled_init_method_normal(sigma, num_layers, use_mup_outer=False, mup_init_scale=1.0):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.normal_(tensor, mean=0.0, std=std)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        else:
            return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


# orthogonal init does not support fp16, so have to patch it
def _orthogonal(tensor, gain=1):

    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

@@ -67,75 +84,111 @@ def _orthogonal(tensor, gain=1):
    return tensor


def orthogonal_init_method(n_layers=1):
def orthogonal_init_method(n_layers=1, use_mup=False, mup_init_scale=1.0):
    """Fills the input Tensor with a (semi) orthogonal matrix, as described in
    Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013)
    Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released)"""

    if use_mup:
        raise ValueError("Orthogonal init needs to be patched to support mup. Disable mup or use a different init method to avoid this error")

    def init_(tensor):
        return _orthogonal(tensor, math.sqrt(2 / n_layers))

    return init_


def xavier_uniform_init_method():
def xavier_uniform_init_method(use_mup_outer=False, mup_init_scale=1.0):
    """Fills the input Tensor with values according to the method described in Understanding the difficulty of
    training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution."""

    def init_(tensor):
        return torch.nn.init.xavier_uniform_(tensor)
    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.xavier_uniform_(tensor)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        else:
            return torch.nn.init.xavier_uniform_(tensor)

    return init_


def xavier_normal_init_method():
def xavier_normal_init_method(use_mup_outer=False, mup_init_scale=1.0):
    """Fills the input Tensor with values according to the method described in Understanding the difficulty of
    training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution."""

    def init_(tensor):
        return torch.nn.init.xavier_normal_(tensor)
    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.xavier_normal_(tensor)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        else:
            return torch.nn.init.xavier_normal_(tensor)

    return init_


def small_init_init_method(dim):
def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0):
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
    std = math.sqrt(2 / (5 * dim))

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.normal_(tensor, mean=0.0, std=std)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        else:
            return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def wang_init_method(n_layers, dim):
def wang_init_method(n_layers, dim, use_mup_outer=False, mup_init_scale=1.0):
    std = 2 / n_layers / math.sqrt(dim)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.normal_(tensor, mean=0.0, std=std)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        else:
            return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def get_init_methods(args):

    if args.use_mup:
        try:
            import mup
        except ModuleNotFoundError:
            print("Please install mup https://github.com/microsoft/mup")
            raise Exception

    def _get(name):
        if name == "normal":
            return init_method_normal(args.init_method_std)
            return init_method_normal(args.init_method_std, args.use_mup, args.mup_init_scale)
        elif name == "scaled_normal":
            return scaled_init_method_normal(args.init_method_std, args.num_layers)
            return scaled_init_method_normal(args.init_method_std, args.num_layers, args.use_mup, args.mup_init_scale)
        elif name == "orthogonal":
            return orthogonal_init_method()
            return orthogonal_init_method(args.use_mup, args.mup_init_scale)
        elif name == "scaled_orthogonal":
            return orthogonal_init_method(args.num_layers)
            return orthogonal_init_method(args.num_layers, args.use_mup, args.mup_init_scale)
        elif name == "xavier_uniform":
            return xavier_uniform_init_method()
            return xavier_uniform_init_method(args.use_mup, args.mup_init_scale)
        elif name == "xavier_normal":
            return xavier_normal_init_method()
            return xavier_normal_init_method(args.use_mup, args.mup_init_scale)
        elif name == "wang_init":
            return wang_init_method(args.num_layers, args.hidden_size)
            return wang_init_method(args.num_layers, args.hidden_size, args.use_mup, args.mup_init_scale)
        elif name == "small_init":
            return small_init_init_method(args.hidden_size)
            return small_init_init_method(args.hidden_size, args.use_mup, args.mup_init_scale)
        else:
            raise NotImplementedError(f"Unknown init method {name}")
