Add replication instructions for training #53

Merged · 4 commits · Sep 27, 2023
2 changes: 1 addition & 1 deletion gpt-neox
Submodule gpt-neox updated 90 files
+1 −8 .gitignore
+3 −3 README.md
+3 −11 configs/neox_arguments.md
+0 −101 configs/v0.4/1.4b-1000-early.yml
+0 −101 configs/v0.4/1.4b-12000-early.yml
+0 −101 configs/v0.4/1.4b-128-early.yml
+0 −101 configs/v0.4/1.4b-2000-early.yml
+0 −101 configs/v0.4/1.4b-24000-early.yml
+0 −101 configs/v0.4/1.4b-256-early.yml
+0 −101 configs/v0.4/1.4b-48000-early.yml
+0 −101 configs/v0.4/1.4b-6000-early.yml
+0 −101 configs/v0.4/1.4b-64-early.yml
+0 −101 configs/v0.4/1.4b-71000-early.yml
+0 −101 configs/v0.4/1.4b-8-early.yml
+0 −99 configs/v0.5/410M_flash.yml
+0 −94 configs/v0.5/6-9B.yml
+0 −97 configs/v0.5/6-9B_flash.yml
+0 −99 configs/v0.5/6-9B_universal_flash.yml
+0 −99 configs/v0.5/70M_flash.yml
+0 −106 configs/v0.5/llama_7B.yml
+0 −106 configs/v0.5/llama_7B_correct-scheduler.yml
+0 −106 configs/v0.5/llama_7B_correct-scheduler_3e-5.yml
+0 −106 configs/v0.5/llama_7B_correct-scheduler_low-lr.yml
+0 −107 configs/v0.5/llama_7B_correct-scheduler_resume.yml
+0 −107 configs/v0.5/llama_7B_correct-scheduler_validation.yml
+0 −104 configs/v0.5/llama_7B_resume.yml
+0 −102 configs/v0.5/llama_7B_scratch.yml
+0 −106 configs/v0.6/llama_7b.yml
+0 −106 configs/v0.6/llama_7b_1e-5.yml
+127 −12 megatron/checkpointing.py
+14 −16 megatron/learning_rates.py
+32 −19 megatron/model/positional_embeddings.py
+113 −67 megatron/model/transformer.py
+0 −3 megatron/neox_arguments/deepspeed_args.py
+34 −4 megatron/neox_arguments/neox_args.py
+1 −1 megatron/tokenizer/train_tokenizer.py
+23 −5 megatron/training.py
+1 −1 requirements/requirements-flashattention.txt
+2 −0 requirements/requirements.txt
+0 −60 slurm/v0.4/1.4b-1000-early.sh
+0 −60 slurm/v0.4/1.4b-12000-early.sh
+0 −60 slurm/v0.4/1.4b-128-early.sh
+0 −60 slurm/v0.4/1.4b-2000-early.sh
+0 −60 slurm/v0.4/1.4b-24000-early.sh
+0 −60 slurm/v0.4/1.4b-256-early.sh
+0 −60 slurm/v0.4/1.4b-48000-early.sh
+0 −60 slurm/v0.4/1.4b-6000-early.sh
+0 −60 slurm/v0.4/1.4b-64-early.sh
+0 −60 slurm/v0.4/1.4b-71000-early.sh
+0 −60 slurm/v0.4/1.4b-8-early.sh
+ slurm/v0.5/01_all.png
+ slurm/v0.5/02_fixed-checkpoint.png
+0 −39 slurm/v0.5/410m_flash.sh
+0 −35 slurm/v0.5/6-9b.sh
+0 −38 slurm/v0.5/6-9b_flash.sh
+0 −39 slurm/v0.5/70m_flash.sh
+0 −21 slurm/v0.5/README.md
+0 −42 slurm/v0.5/convert_llama_to_hf.sh
+0 −39 slurm/v0.5/llama_7b.sh
+0 −39 slurm/v0.5/llama_7b_correct-scheduler.sh
+0 −39 slurm/v0.5/llama_7b_correct-scheduler_3e-5.yml
+0 −39 slurm/v0.5/llama_7b_correct-scheduler_low-lr.sh
+0 −39 slurm/v0.5/llama_7b_correct-scheduler_resume.sh
+0 −39 slurm/v0.5/llama_7b_correct-scheduler_validation.sh
+0 −39 slurm/v0.5/llama_7b_resume.sh
+0 −39 slurm/v0.5/llama_7b_scratch.sh
+0 −42 slurm/v0.6/convert_llama_to_hf.sh
+0 −39 slurm/v0.6/llama_7b.yml
+0 −39 slurm/v0.6/llama_7b_1e-5.yml
+0 −15 tools/README.md
+0 −50 tools/checkpoints/README.md
+0 −434 tools/checkpoints/convert_v1.0_to_hf.py
+0 −225 tools/checkpoints/deepspeed_to_deepspeed.py
+0 −360 tools/checkpoints/ds_to_universal.py
+0 −128 tools/checkpoints/inspect_ds_checkpoint.py
+0 −441 tools/convert_llama_sequential_to_hf.py
+2 −2 tools/convert_raw_llama_weights_to_neox.py
+0 −0 tools/convert_sequential_to_hf.py
+18 −12 tools/corpora.py
+0 −0 tools/inspect_checkpoints.py
+0 −0 tools/kill.sh
+0 −0 tools/killall.sh
+0 −0 tools/merge20b.py
+0 −0 tools/merge_datasets.py
+0 −0 tools/merge_mp_partitions.py
+5 −9 tools/preprocess_data.py
+0 −0 tools/sync.sh
+0 −0 tools/sync_cmd.sh
+0 −0 tools/syncdir.sh
+0 −0 tools/upload.py
39 changes: 39 additions & 0 deletions pretraining/34b_launch_script.sh
@@ -0,0 +1,39 @@
#!/bin/bash
#... your SLURM arguments here
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=12
#SBATCH --gres=gpu:8
#SBATCH --output=34b_replication_%j.out
#SBATCH --error=34b_replication_%j.out
#SBATCH --exclusive
#SBATCH --open-mode=append
#SBATCH --requeue

# setup the environment using the script we created before
source /fsx/proj-mathlm/conda_setup_deeperspeed.sh
#source /fsx/quentin/setup.sh

ds_report

# set distributed env variable flags such as NCCL_DEBUG here

export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=12802
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`

# Move to the gpt-neox install
TRAIN_PATH=/path/to/gpt-neox
cd $TRAIN_PATH

# Write the hostfile for this job here
# Should write to a hostfile that contains lines of format `<machine IP> slots=<NUM_GPUS_PER_NODE>`
bash /helper/script/write_hostfile.sh
export DLTS_HOSTFILE=path/to/hostfile/hosts_$SLURM_JOBID
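# A hypothetical inline alternative to the helper script above (illustrative sketch
# only; assumes 8 GPUs per node and that hostnames are acceptable in place of IPs):
# for host in $HOSTNAMES; do
#   echo "$host slots=8" >> "$DLTS_HOSTFILE"
# done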


# launch distributed job. If using `"deepspeed_slurm": true` and `"launcher": "slurm"` on a SLURM cluster,
# then NeoX will handle the creation of a distributed run across 256 gpus.
python $TRAIN_PATH/deepy.py $TRAIN_PATH/train.py \
--conf_dir /path/to/math-lm/pretraining llemma_34b.yml data_mixture.yml
67 changes: 67 additions & 0 deletions pretraining/README.md
@@ -0,0 +1,67 @@
# Llemma Pretraining

This subfolder contains instructions for replicating pretraining of the Llemma models.

Training was performed across 256 A100 GPUs using the GPT-NeoX library. We include the configuration files and a sample SLURM job script needed to replicate training on a SLURM-managed cluster.


## Replicating Training


### Set up environment

We provide `env_dump.txt`, a dump of our training environment.

You can install all required packages via
```bash
pip install -r env_dump.txt
```
Make sure your DeepSpeed installation is built from the `new-fix` branch of DeeperSpeed (https://github.com/EleutherAI/DeeperSpeed/tree/new-fix), and install the fused kernels for GPT-NeoX via `python ./megatron/fused_kernels/setup.py install` from within your GPT-NeoX checkout.
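
A minimal sketch of those two steps, assuming your GPT-NeoX checkout lives at the placeholder path `/path/to/gpt-neox` (the DeepSpeed pin below is the one recorded in `env_dump.txt`):
```bash
# Install the pinned DeeperSpeed branch
pip install "git+https://github.com/EleutherAI/DeeperSpeed.git@new-fix#egg=deepspeed"

# Build and install the fused kernels from within the GPT-NeoX checkout
cd /path/to/gpt-neox
python ./megatron/fused_kernels/setup.py install
```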


### Converting Llama 2 checkpoints into NeoX format

First, download the CodeLlama 7b or 34b weights from Meta AI and rename the download folder to `7B` or `34B` within your CodeLlama checkout.

Then, convert the model into the checkpoint format expected by GPT-NeoX:

Sample command for 7b Meta->NeoX format:
```bash
python tools/convert_raw_llama_weights_to_neox.py --input_dir /path/to/codellama/repo --config_file /path/to/this/repo/math-lm/pretraining/llemma_7b.yml --output_dir /path/to/save/into/ --num_output_shards {TP_DEGREE, we use 2}
```

Sample command for 34b Meta->NeoX format (requires large amounts of GPU VRAM or CPU RAM; pass `CUDA_VISIBLE_DEVICES=""` to perform the conversion on CPU; the 34b conversion may take a while):
```bash
CUDA_VISIBLE_DEVICES="" python tools/convert_raw_llama_weights_to_neox.py --input_dir /path/to/codellama/repo --config_file /path/to/this/repo/math-lm/pretraining/llemma_34b.yml --output_dir /path/to/save/into/ --num_output_shards {TP_DEGREE, we use 8}
```


### Check Out Codebase

Next, check out the commit used to train the model you are replicating; a sketch of the checkout steps follows the link below.

* 7b / 34b: https://github.com/EleutherAI/gpt-neox/commit/e59c873ee779df2d7f182deb6ad34f290a077ea4
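
A minimal sketch of the checkout, using the repository URL and commit hash linked above:
```bash
git clone https://github.com/EleutherAI/gpt-neox.git
cd gpt-neox
git checkout e59c873ee779df2d7f182deb6ad34f290a077ea4
```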

### Launching Training

Then, edit the provided YML config files to point at your own system's locations for checkpoints and data files, and edit the provided SLURM job script accordingly, or launch the job across multiple nodes using your own cluster's orchestration.
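
On a SLURM cluster, submitting the edited skeleton script might look like the following (a sketch; the path assumes you launch from the root of this repository):
```bash
sbatch pretraining/34b_launch_script.sh
```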

**Tip**: the global batch size scales with the number of nodes, so if you run on a number of nodes other than 32, scale `gradient_accumulation_steps` accordingly.

We used a global batch size of 4M tokens. The global batch size is `seq_len * num_gpus * (train_micro_batch_size_per_gpu * gradient_accumulation_steps) / (model_parallel_size * max(pipeline_parallel_size, 1))`.
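
As a concrete check against the 34b settings in this folder (`seq_length` 4096, 32 nodes × 8 GPUs = 256 GPUs, `train_micro_batch_size_per_gpu` 2, `gradient_accumulation_steps` 16, `model_parallel_size` 8, `pipe_parallel_size` 0):
```bash
# 4096 * 256 * (2 * 16) / (8 * max(0, 1)) = 4,194,304 tokens ≈ 4M per step
echo $(( 4096 * 256 * (2 * 16) / (8 * 1) ))
```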


## Contents

The files in this folder are as follows:

* `34b_launch_script.sh` contains a skeleton SLURM job script to launch training with NeoX across 32 nodes.

* `data_mixture.yml` contains a list of the domain weights for the final training run.

* `llemma_7b.yml` is a cleaned-up version of the config file used to train Llemma-7b.

* `llemma_34b.yml` is a cleaned-up version of the config file used to train Llemma-34b.

* `env_dump.txt` is a dump of the virtual environment used in training, created via `pip freeze`.
6 changes: 6 additions & 0 deletions pretraining/data_mixture.yml
@@ -0,0 +1,6 @@
{
"train-data-paths": ["/fsx/proj-mathlm/proof-pile_llama/train/arxiv-rp/arxiv-rp_text_document", "/fsx/proj-mathlm/open-web-math-v1.2_llama/train/open-web-math/open-web-math_text_document", "/fsx/proj-mathlm/code-with-proofsteps_llama/train/code-with-proofsteps/code-with-proofsteps_text_document", "/fsx/proj-mathlm/proof-pile_llama/train/pile-sample/pile-sample_text_document", "/fsx/proj-mathlm/code-rp_llama/train/code-rp/code-rp_text_document"],
"train-data-weights": [2, 4, 1, 0.147368, 0.221053],
"valid-data-paths": ["/fsx/proj-mathlm/proof-pile_llama/validation/arxiv-rp/arxiv-rp_text_document", "/fsx/proj-mathlm/open-web-math-v1.2_llama/validation/open-web-math/open-web-math_text_document", "/fsx/proj-mathlm/code-with-proofsteps_llama/validation/code-with-proofsteps/code-with-proofsteps_text_document"],
"test-data-paths": ["/fsx/proj-mathlm/proof-pile_llama/test/arxiv-rp/arxiv-rp_text_document", "/fsx/proj-mathlm/open-web-math-v1.2_llama/test/open-web-math/open-web-math_text_document", "/fsx/proj-mathlm/code-with-proofsteps_llama/test/code-with-proofsteps/code-with-proofsteps_text_document"],
}
116 changes: 116 additions & 0 deletions pretraining/env_dump.txt
@@ -0,0 +1,116 @@
absl-py==1.4.0
aiohttp==3.8.4
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.2
attrs==23.1.0
best-download==0.0.9
boto3==1.28.22
botocore==1.31.22
certifi==2023.5.7
chardet==5.1.0
charset-normalizer==3.1.0
click==8.1.4
cmake==3.26.4
colorama==0.4.6
CPCargo @ git+https://github.com/samikama/CPCargo@efbf0a5f2ad893c0eee4caae6098001b74be62d8
DataProperty==1.0.0
datasets==2.13.1
DeepSpeed @ git+https://github.com/EleutherAI/DeeperSpeed.git@new-fix#egg=deepspeed
dill==0.3.6
docker-pycreds==0.4.0
einops==0.6.1
filelock==3.12.2
flash-attn==2.0.0.post1
frozenlist==1.3.3
fsspec==2023.6.0
ftfy==6.1.1
fused-kernels @ file:///fsx/hailey/math-lm/gpt-neox/megatron/fused_kernels
gitdb==4.0.10
GitPython==3.1.32
hf_transfer==0.1.3
hjson==3.1.0
huggingface-hub==0.16.4
idna==3.4
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.1
jsonlines==3.1.0
lit==16.0.6
lm-dataformat @ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
lm-eval==0.3.0
MarkupSafe==2.1.3
mbstrdecoder==1.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
networkx==3.1
ninja==1.11.1
nltk==3.8.1
numexpr==2.8.4
numpy==1.25.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
openai==0.27.8
packaging==23.1
pandas==2.0.3
pathtools==0.1.2
pathvalidate==3.0.0
portalocker==2.7.0
protobuf==4.23.4
psutil==5.9.5
py-cpuinfo==9.0.0
pyarrow==12.0.1
pybind11==2.10.4
pycountry==22.3.5
pydantic==1.10.11
pytablewriter==1.0.0
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
regex==2023.6.3
rehash==1.0.1
requests==2.31.0
rouge-score==0.1.2
s3transfer==0.6.1
sacrebleu==1.5.0
safetensors==0.3.1
scikit-learn==1.3.0
scipy==1.11.1
sentencepiece==0.1.99
sentry-sdk==1.28.1
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
sqlitedict==2.1.0
sympy==1.12
tabledata==1.3.1
tcolorpy==0.1.3
threadpoolctl==3.1.0
tiktoken==0.4.0
tokenizers==0.13.3
torch==2.0.1
tqdm==4.65.0
tqdm-multiprocess==0.0.11
transformers==4.31.0
triton==2.0.0
typepy==1.3.1
typing_extensions==4.7.1
tzdata==2023.3
ujson==5.8.0
urllib3==1.26.16
wandb==0.15.5
watchdog==3.0.0
wcwidth==0.2.6
xxhash==3.2.0
yarl==1.9.2
zstandard==0.21.0
108 changes: 108 additions & 0 deletions pretraining/llemma_34b.yml
@@ -0,0 +1,108 @@
{
"pipe_parallel_size": 0,
"model_parallel_size": 8,
"make_vocab_size_divisible_by": 1,

# model settings
"num_layers": 48,
"hidden_size": 8192,
"num_attention_heads": 64,
"attention_type": "groupedquery",
"num_kv_heads": 8,
"seq_length": 4096,
"max_position_embeddings": 4096,
"pos_emb": "rotary",
"rotary_pct": 1,
"rotary_emb_base": 1000000,
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,

"attention_config": [[["flash"], 48]],

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
"activation": "silu",

"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00005,
"betas": [0.9, 0.95],
"eps": 1.0e-8
}
},

"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"allgather_bucket_size": 1260000000,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1260000000,
"contiguous_gradients": true,
"cpu_offload": false
},

"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 16,
"data_impl": "mmap",

"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

"precision": "bfloat16",
"fp32_allreduce": true,
"bf16": {
"enabled": true
},
"data_types": {
"grad_accum_dtype": "fp32"
},

"train_iters": 12000,
"lr_decay_iters": 12000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"decay_lr_to": 0.033,
"warmup_iters": 500,
"checkpoint_factor": 250,
"eval_interval": 250,
"eval_iters": 25,

"log_interval": 1,
"steps_per_print": 1,
"wall_clock_breakdown": true,

"tokenizer_type": "SPMTokenizer",
"vocab-file": "codellama/tokenizer.model", # use tokenizer.model from Meta CodeLlama download

"save": "/fsx/proj-mathlm/saved-weights/34b_1epoch",
# "load": "" # set to same as "save" to resume from intermediate finetuning step
"load": "/path/to/converted/codellama_34b_weights_with_mp8",

"finetune": true, # set to false once resuming from intermediate finetuning step
"checkpoint_validation_with_forward_pass": true,


"use_wandb": true,
"wandb_group": "34b-codellama-5e-5lr",
"wandb_project": "math-lm",
"wandb_team": "your-teamname-here",
"wandb_host": "https://api.wandb.ai",

"launcher": "slurm",
"deepspeed_slurm": true
}