Update Llama2 readme and test
A9isha committed Mar 25, 2024
1 parent 6dfd857 commit 6e70674
Showing 3 changed files with 103 additions and 42 deletions.
42 changes: 0 additions & 42 deletions end_to_end/test_llama2.sh

This file was deleted.

73 changes: 73 additions & 0 deletions end_to_end/test_llama2_7b.sh
@@ -0,0 +1,73 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Llama2-7b

# The flow of this file is as follows:
# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) to a local directory, then convert this PyTorch checkpoint into the Orbax checkpoint format used by MaxText.
# 2. Run decoding and finetuning of Llama2-7b with this converted checkpoint, and also run pretraining of Llama2-7b.
# 3. Run decoding from the finetuned weights
# 4. Convert the scanned checkpoint from step #1 into unscanned checkpoint format and run more efficient decoding.


set -ex
idx=$(date +%Y-%m-%d-%H-%M)

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset
export ASYNC_CHECKPOINTING=false

# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU
pip install torch --index-url https://download.pytorch.org/whl/cpu
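# (Optional sanity check, not part of the original flow: confirm the CPU-only torch wheel imports cleanly before running the conversion script.)
python3 -c "import torch; print(torch.__version__)"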

# We define a variable for the path to the Meta checkpoint. Non-Googlers please remember to point `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt

# In the following command, we copy Meta's checkpoint into the local directory `/tmp/`.
# You can use a local directory other than `/tmp/`; if you do, pass the same local path as `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/
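# (Optional sanity check, not part of the original flow: after the copy, /tmp/meta-ckpt should contain Meta's PyTorch files -- for 7B typically `consolidated.00.pth` and `params.json`.)
ls /tmp/meta-ckpt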

# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own
export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext

# Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint to `maxtext-model-path`
python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}

# We define `CONVERTED_CHECKPOINT` to refer to the exact checkpoint subdirectory inside `CONVERTED_CHECKPOINT_PATH`. This makes the path easier to use in the `train.py` and `decode.py` commands
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
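# (Optional sanity check, not part of the original flow: list the converted Orbax checkpoint to confirm the `0/items` subdirectory was written.)
gcloud storage ls ${CONVERTED_CHECKPOINT}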

# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can convert it by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx}
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true

# Like before, we define `UNSCANNED_CKPT_PATH` to refer to the exact checkpoint subdirectory
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items

# We run decoding on `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint (aka `CONVERTED_CHECKPOINT`), for efficient decoding. Note that this checkpoint only has parameters and no optimizer state, so we load it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert`
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product scan_layers=false


# We can also run decoding (albeit in a less optimized way) using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state, so we load it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`
# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert`
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product

# Alternatively, we can skip decoding and go straight to finetuning using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we load it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint format is what enables efficient finetuning
python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 checkpoint_period=5

# We also run pretraining of Llama2-7b; this is similar to the finetuning command, except we don't pass any checkpoint directory to load parameters from
python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024

# Now, run decoding on the checkpoint generated by our finetune run. Note that the finetune run saves the `full state`, which has both parameters and optimizer state. For decoding, we only need the parameters, so we use `MaxText/generate_param_only_checkpoint.py` to convert
# the full state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, here the checkpoint saved at finetuning step 5
# Also, `force_unroll=true` converts the output parameter-only checkpoint into the unscanned format for efficient decoding
export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx}
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true

# Like before, we define `NEW_CKPT_PATH` to refer to the exact checkpoint subdirectory
export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items

# We run decoding on the fine-tuned parameter checkpoint
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false
30 changes: 30 additions & 0 deletions getting_started/Run_Llama2.md
@@ -0,0 +1,30 @@
<!--
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

## About Llama2

MaxText supports [Llama2](https://llama.meta.com/llama2) pretraining, finetuning, and decoding for its 7B and 70B flavors. To get started with decoding and finetuning Llama2, you will first need to download the weights along with the tokenizer from [Meta](https://llama.meta.com/llama-downloads).

The file [test_llama2_7b.sh](https://github.com/google/maxtext/end_to_end/test_llama2_7b.sh) details how to convert the PyTorch weights into the Orbax checkpoint format and then use them for decoding and finetuning. The same script also shows how to run pretraining and how to run decoding on the finetuned model checkpoint.
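For orientation, the sketch below condenses the main steps from that script into a few commands. The bucket paths are placeholders you must replace with buckets you own, and the full set of flags (including the decoding asserts) lives in the script itself.

```sh
# Placeholder bucket -- substitute your own.
export CONVERTED=gs://<your-bucket>/llama2-7b/decode-ckpt-maxtext

# 1. Convert Meta's PyTorch checkpoint to Orbax (CPU-only; no TPU needed).
python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt \
  --model-size llama2-7b --maxtext-model-path ${CONVERTED}

# 2. Decode from the converted checkpoint (parameters only, no optimizer state).
python3 MaxText/decode.py MaxText/configs/base.yml \
  load_parameters_path=${CONVERTED}/0/items run_name=llama2_decode \
  base_output_directory=gs://<your-bucket>/logs model_name='llama2-7b' \
  per_device_batch_size=1 prompt="I love to" attention=dot_product
```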

### MaxText supports pretraining and finetuning with high performance

Model flop utilization (MFU, the ratio of achieved model FLOP/s to the hardware's peak FLOP/s) for training on v4, v5p, and v5e TPUs with MaxText:


| Model | v4-128 (bf16) | v5p-128 (bf16) | v5e-256 (bf16) |
| ---------- | -------------- | -------------- | -------------- |
| Llama2-70b | 57% | 65% | 57% |
