-
Notifications
You must be signed in to change notification settings - Fork 236
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
103 additions
and
42 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#!/bin/bash | ||
|
||
# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Llama2-7b | ||
|
||
# The flow of this file is as follows: | ||
# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) in your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText. | ||
# 2. Run decoding, finetuning of Llama2-7b with this converted checkpoint. Also, run pretraining of Llama2-7b. | ||
# 3. Run decoding from the finetuned weights | ||
# 4. Convert the scanned checkpoint from step #1 into unscanned checkpoint format and run more efficient decoding. | ||
|
||
|
||
set -ex | ||
idx=$(date +%Y-%m-%d-%H-%M) | ||
|
||
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run | ||
export BASE_OUTPUT_DIRECTORY=gs:https://runner-maxtext-logs | ||
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data | ||
export DATASET_PATH=gs:https://maxtext-dataset | ||
export ASYNC_CHECKPOINTING=false | ||
|
||
# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU | ||
pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
|
||
# We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint | ||
export META_CHECKPOINT_PATH=gs:https://maxtext-llama/llama2-7b/meta-ckpt | ||
|
||
# In the following command, we are copying Meta's checkpoint into a local directory `tmp`. | ||
# You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py` | ||
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/ | ||
|
||
# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own | ||
export CONVERTED_CHECKPOINT_PATH=gs:https://maxtext-llama/test/${idx}/decode-ckpt-maxtext | ||
|
||
#Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` | ||
python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH} | ||
|
||
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory exactly inside `CONVERTED_CHECKPOINT_PATH`. This way it is easier to use this path in the `train.py` and `decode.py` commands | ||
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items | ||
|
||
# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. | ||
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. | ||
export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx} | ||
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true | ||
|
||
# Like before, we define `UNSCANNED_CKPT_PATH` to refer to the checkpoint subdirectory exactly | ||
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items | ||
|
||
# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint aka `CONVERTED_CHECKPOINT`. Note that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` | ||
# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert` | ||
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product scan_layers=false | ||
|
||
|
||
# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` | ||
# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert` | ||
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product | ||
|
||
# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning | ||
python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 checkpoint_period=5 | ||
|
||
# We also run pre-training of Llama2-7b, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from | ||
python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 | ||
|
||
# Now, run decoding on the checkpoint generated from our finetune run. Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert | ||
# the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, say the checkpoint saved at finetuning step #5 | ||
# Also, `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding | ||
export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx} | ||
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true | ||
|
||
# Like before, we define `NEW_CKPT_PATH` to refer to the checkpoint subdirectory exactly | ||
export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items | ||
|
||
# We run decoding on the fine-tuned parameter checkpoint | ||
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
<!-- | ||
Copyright 2023 Google LLC | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
https://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
--> | ||
|
||
## About Llama2 | ||
|
||
MaxText supports [Llama2](https://llama.meta.com/llama2) pretraining, finetuning and decoding for its 7B and 70B flavors. To get started on decoding and finetuning of Llama2, you will first need to download weights along with its tokenizer from [Meta](https://llama.meta.com/llama-downloads). | ||
|
||
The file [test_llama2_7b.sh](https://github.com/google/maxtext/end_to_end/test_llama2_7b.sh) provides details on how to convert the PyTorch weights in orbax checkpoint format, and thereafter use it for running decoding and finetuning. [test_llama2_7b.sh](https://github.com/google/maxtext/end_to_end/test_llama2_7b.sh) also shows how to run pretraining and also how to run decoding on the finetuned model checkpoint. | ||
|
||
### MaxText supports pretraining and finetuning with high performance. | ||
|
||
Model Flop utilization for training on v5e and v5p and v4 TPUs with MaxText. | ||
|
||
|
||
| Model | v4-128 (bf16) | v5p-128 (bf16) | v5e-256 (bf16) | | ||
| ---------- | -------------- | -------------- | -------------- | | ||
| Llama2-70b | 57% | 65% | 57% | |