20B Release (#533)
* change the eye links to mystic mirror

* update README and neox_arguments.md

* clarify some arguments in prepare / preprocess_data.py

* add 20B config

* add `sample_input_file` and `sample_output_file` to cmd line args

* update README.md

* Add paper link
sdtblck committed Feb 10, 2022
1 parent ac3d808 commit c560814
Showing 8 changed files with 515 additions and 201 deletions.
235 changes: 150 additions & 85 deletions README.md

Large diffs are not rendered by default.

110 changes: 110 additions & 0 deletions configs/20B.yml
@@ -0,0 +1,110 @@
# DISCLAIMER: This is the configuration file for the GPT-NeoX-20B model as it was trained on 96x 40GB A100
# GPUs. Depending on your system configuration, you may need to change some parameters in order to fit
# the model in memory.

{
# Tokenizer / checkpoint settings - you will need to change these to the location you have them saved in
"vocab-file": "./20B_checkpoints/20B_tokenizer.json",
"save": "./20B_checkpoints",
"load": "./20B_checkpoints",

# If finetuning, edit the following to the location of your finetuning dataset:
"data-path": "./data/pile_20B_tokenizer/pile_20B_tokenizer_text_document",

# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe-parallel-size": 4,
"model-parallel-size": 2,

# model settings
"num-layers": 44,
"hidden-size": 6144,
"num-attention-heads": 64,
"seq-length": 2048,
"max-position-embeddings": 2048,
"norm": "layernorm",
"pos-emb": "rotary",
"rotary_pct": 0.25,
"no-weight-tying": true,
"gpt_j_residual": true,
"output_layer_parallelism": "column",
"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": true,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.97e-4,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},

"min_lr": 0.97e-5,
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 1260000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 1260000000,
"contiguous_gradients": True,
"cpu_offload": False
},

# batch / data settings (assuming 96 GPUs)
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 32,
"data-impl": "mmap",
"split": "995,4,1",

# activation checkpointing
"checkpoint-activations": true,
"checkpoint-num-layers": 1,
"partition-activations": false,
"synchronize-each-layer": true,

# regularization
"gradient_clipping": 1.0,
"weight-decay": 0.01,
"hidden-dropout": 0,
"attention-dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train-iters": 150000,
"lr-decay-iters": 150000,

"distributed-backend": "nccl",
"lr-decay-style": "cosine",
"warmup": 0.01,
"save-interval": 500,
"eval-interval": 1000,
"eval-iters": 10,

# logging
"log-interval": 2,
"steps_per_print": 2,
"wall_clock_breakdown": false,

### NEW DATA: ####
"tokenizer_type": "HFTokenizer",
"tensorboard-dir": "./tensorboard",
"log-dir": "./logs",

}
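
As a quick sanity check of how the parallelism and batch settings above combine (assuming the 96-GPU setup described in the header comment), the derived quantities work out roughly as follows; this is an illustrative calculation, not code from the repository:

```python
# Back-of-the-envelope check of the 20B.yml batch / parallelism settings,
# assuming 96 GPUs as stated in the config's header comment.
gpus = 96
pipe_parallel, model_parallel = 4, 2
micro_batch_per_gpu, grad_accum_steps = 4, 32
seq_length = 2048
train_iters, warmup_frac = 150_000, 0.01

data_parallel = gpus // (pipe_parallel * model_parallel)               # 12 model replicas
global_batch = micro_batch_per_gpu * grad_accum_steps * data_parallel  # 1536 sequences per step
tokens_per_step = global_batch * seq_length                            # ~3.1M tokens per step
warmup_steps = int(train_iters * warmup_frac)                          # 1500 warmup steps
print(data_parallel, global_batch, tokens_per_step, warmup_steps)
```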
115 changes: 75 additions & 40 deletions configs/neox_arguments.md
@@ -103,7 +103,7 @@ Logging Arguments

- **git_hash**: str

Default = 875f8ad
Default = a593ce2

current git hash of repository

@@ -237,11 +237,11 @@ Model Arguments



- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm', 'apexlayernorm']
- **norm**: typing.Literal['layernorm', 'rmsnorm', 'scalenorm']

Default = layernorm

Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "apexlayernorm".
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".



@@ -269,7 +269,7 @@ Model Arguments



- **pos_emb**: typing.Literal['learned', 'rotary', 'sinusoidal', 'rpe', 'none']
- **pos_emb**: typing.Literal['learned', 'rotary', 'sinusoidal', 'rpe', 'alibi', 'none']

Default = learned

@@ -375,14 +375,6 @@ Model Arguments



- **apply_residual_connection_post_layernorm**: bool

Default = False

If set, use original BERT residual connection ordering.



- **activation**: typing.Literal['gelu', 'geglu', 'relu', 'softsign', 'swish', 'mish']

Default = gelu
@@ -506,6 +498,41 @@ Model Arguments



- **gpt_j_residual**: bool

Default = False

If false, we use the conventional residual path:
x = x + attn(ln1(x))
x = x + mlp(ln2(x))
Otherwise, we use the residual path from GPT-J, which offers a slight speedup:
x = ln(x)
x = x + attn(x) + mlp(x)
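
A minimal sketch of the two orderings, assuming `attn` and `mlp` are arbitrary sub-modules with matching input/output shapes; it only illustrates the pseudocode above and is not the repository's fused implementation:

```python
import torch.nn as nn

class Block(nn.Module):
    """Toy transformer block contrasting the two residual orderings."""
    def __init__(self, hidden_size, attn, mlp, gpt_j_residual=False):
        super().__init__()
        self.gpt_j_residual = gpt_j_residual
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.attn, self.mlp = attn, mlp

    def forward(self, x):
        if self.gpt_j_residual:
            # GPT-J style: a single LayerNorm feeds attention and MLP in parallel
            h = self.ln1(x)
            return x + self.attn(h) + self.mlp(h)
        # conventional: two sequential pre-norm residual branches
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
```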



- **soft_prompt_tuning**: dict

Default = None

Dictionary configuring the soft prompt tuning parameters.
If enabled, will train *only* the soft prompt and freeze the rest of the model.
parameters in the dict are:
'enabled': bool = True # enables soft prompting
'num_tokens': int = 10 # length of the soft prompt in tokens
'init_string': str = '' # if provided, initialize the soft prompt with the word embeddings of this string
'init_range': float = 0.5 # if no init string is provided, initialize the soft prompt with a uniform distribution between -init_range and init_range
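
As a rough illustration of the idea (not GPT-NeoX's actual implementation), a soft prompt is a small trainable embedding block prepended to the frozen model's input embeddings; `init_embeddings` below is a hypothetical argument standing in for the word embeddings of `init_string`:

```python
import torch
import torch.nn as nn

class SoftPrompt(nn.Module):
    """Trainable prompt of shape [num_tokens, hidden]; the rest of the model stays frozen."""
    def __init__(self, hidden_size, num_tokens=10, init_range=0.5, init_embeddings=None):
        super().__init__()
        if init_embeddings is not None:
            # initialize from the word embeddings of init_string (length handling omitted)
            weight = init_embeddings.detach().clone()
        else:
            # uniform init in [-init_range, init_range]
            weight = torch.empty(num_tokens, hidden_size).uniform_(-init_range, init_range)
        self.prompt = nn.Parameter(weight)  # the only parameters that receive gradients

    def forward(self, token_embeddings):  # token_embeddings: [batch, seq, hidden]
        batch = token_embeddings.size(0)
        prompt = self.prompt.unsqueeze(0).expand(batch, -1, -1)
        return torch.cat([prompt, token_embeddings], dim=1)  # [batch, num_tokens + seq, hidden]
```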



- **output_layer_parallelism**: typing.Literal['row', 'column']

Default = row

Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
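
For intuition only (this is not the Megatron-style implementation, and the weight is written as [hidden, vocab] purely for clarity), the two layouts shard the same matrix along different axes and recover identical logits:

```python
import torch

hidden, vocab, world_size = 8, 16, 2
weight = torch.randn(hidden, vocab)
x = torch.randn(1, hidden)

# column parallelism: shard over the vocab dim; each rank emits logits for its vocab shard
col_shards = torch.chunk(weight, world_size, dim=1)
col_out = torch.cat([x @ w for w in col_shards], dim=1)  # gather shards over vocab

# row parallelism: shard over the hidden dim; partial logits are summed (all-reduce)
row_shards = torch.chunk(weight, world_size, dim=0)
row_out = sum(xi @ w for xi, w in zip(torch.chunk(x, world_size, dim=1), row_shards))

assert torch.allclose(col_out, row_out, atol=1e-5)
assert torch.allclose(col_out, x @ weight, atol=1e-5)
```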



## NeoXArgsOptimizer

Optimizer Arguments
@@ -520,6 +547,14 @@ Optimizer Arguments



- **use_bnb_optimizer**: bool

Default = False

Whether to enable the bitsandbytes optimizers



- **zero_stage**: int

Default = None
@@ -614,22 +649,6 @@ Misc. Arguments



- **reset_position_ids**: bool

Default = False

Reset position ids after end-of-document token.



- **reset_attention_mask**: bool

Default = False

Reset self attention mask after end-of-document token.



- **eod_mask_loss**: bool

Default = False
@@ -734,6 +753,14 @@ Misc. Arguments



- **global_num_gpus**: int

Default = None

Set during launching



## NeoXArgsParallelism

Parallelism Arguments
@@ -797,7 +824,7 @@ Text Generation arguments

- **text_gen_type**: str

Default = None
Default = unconditional

How to generate text/sample the model.
Options: `unconditional`, `input-file`, `interactive`
@@ -846,17 +873,17 @@ Text Generation arguments

- **sample_output_file**: str

Default = None
Default = samples.txt

Output file



- **num_samples**: int

Default = 0
Default = 1

Number of samples to generate unconditionally, defaults to 0 and interactive conditional sampling
Number of samples to generate unconditionally, defaults to 1 and interactive conditional sampling



@@ -885,14 +912,6 @@ Text Generation arguments



- **char_level_ppl**: bool

Default = False

Whether to calculate character level perplexity as well as token level perplexity. (may incur a time cost)



## NeoXArgsTokenizer

Tokenizer Arguments
@@ -1037,6 +1056,14 @@ Training Arguments



- **config_files**: dict

Default = None

Store of original config files mapping config filename to file contents



- **load**: str

Default = None
@@ -1327,6 +1354,14 @@ Training Arguments



- **char_level_ppl**: bool

Default = False

Whether to calculate character level perplexity as well as token level perplexity. (may incur a time cost)



## NeoXArgsDeepspeedConfig

Args for deepspeed config
34 changes: 30 additions & 4 deletions megatron/neox_arguments/arguments.py
@@ -172,13 +172,15 @@ def from_ymls(cls, paths_to_yml_files: List[str], overwrite_values: Dict = None)
# load original config files to save unchanged with checkpoint
# saving the original config retains comments
filename = os.path.basename(conf_file_name)
assert filename not in config_files, "At least two config files have the same filename. This will result in conflicts when saving out configs with the checkpoint in one single directory. Please use unique names for configs."
assert (
filename not in config_files
), "At least two config files have the same filename. This will result in conflicts when saving out configs with the checkpoint in one single directory. Please use unique names for configs."
config_files[filename] = open(conf_file_name).read()

# add config file content to neox args to make them accessible in code
# this is used when saving checkpoints
config["config_files"] = config_files

# Configuration parameters not specified
params_not_in_config = sorted(
list(set(cls.__dataclass_fields__.keys()) - set(config.keys()))
@@ -279,6 +281,21 @@ def consume_deepy_args(cls):
default=None,
help="prefix to append to eval results file",
)
group = parser.add_argument_group(title="Generation args")
group.add_argument(
"-i",
"--sample_input_file",
type=str,
default=None,
help="Optionally overwrite `sample_input_file` for generate.py",
)
group.add_argument(
"-o",
"--sample_output_file",
type=str,
default=None,
help="Optionally overwrite `sample_output_file` for generate.py",
)
args_parsed = parser.parse_args()

# Validate user_script exists
@@ -301,8 +318,10 @@ def consume_deepy_args(cls):
overwrite_values[k] = v

# load args
neox_args = cls.from_ymls(paths_to_yml_files=conf_files, overwrite_values=overwrite_values)

neox_args = cls.from_ymls(
paths_to_yml_files=conf_files, overwrite_values=overwrite_values
)

if neox_args.wandb_group is not None:
# concat the wandb group name with a uid to make sure it's unique
import wandb
@@ -770,6 +789,13 @@ def calculate_derived(self):
if self.test_data_paths and (self.test_data_weights is None):
self.test_data_weights = [1.0] * len(self.test_data_paths)

# if a sample input file is provided, default text_gen_type type to input-file
if self.text_gen_type is None:
if self.sample_input_file:
self.update_value("text_gen_type", "input-file")
else:
self.update_value("text_gen_type", "unconditional")

############################################################################################################################
# start of validation functions
