Add more tests for Mixtral
RissyRan committed May 3, 2024
1 parent e8b53e5 commit ffcd34c
Showing 2 changed files with 25 additions and 9 deletions.
16 changes: 12 additions & 4 deletions end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh
@@ -2,7 +2,7 @@

# This file, combined with step 2 in the same directory, runs on a daily basis and demonstrates:
# 1. Converts the Mistral PyTorch checkpoint to MaxText(orbax) format using a CPU VM.
-# 2. Takes the MaxText (orbax) checkpoint to run inference and fine-tuning on a TPU VM.
+# 2. Takes the MaxText(orbax) checkpoint to run inference, fine-tuning, and pre-training on a TPU VM.

# The flow of this file is to convert the Mistral PyTorch checkpoint to MaxText (orbax) format using a CPU VM.

@@ -18,8 +18,16 @@ if [ -z "${BASE_OUTPUT_PATH}" ]; then
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

-# Download checkpoint, convert it to MaxText(orbax) format
+# Download checkpoint
pip3 install torch
gcloud storage cp -r gs://maxtext-external/mixtral-8x7B-v0.1-Instruct /tmp
-JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mixtral-8x7B-v0.1-Instruct --model-size mixtral-8x7b --maxtext-model-path ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext/
-echo "Wrote MaxText compatible checkpoint to ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext"

+# Convert it to MaxText(orbax) format - scanned ckpt
+JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mixtral-8x7B-v0.1-Instruct --model-size mixtral-8x7b --maxtext-model-path ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt/
+echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt"

+# Generate unscanned ckpt for efficient decoding test
+export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt/0/items
+export RUN_NAME=unscanned_ckpt
+JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x7b' force_unroll=true
+echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints"
18 changes: 13 additions & 5 deletions end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -2,9 +2,9 @@

# This file, combined with step 1 in the same directory, runs on a daily basis and demonstrates:
# 1. Converts the Mistral PyTorch checkpoint to MaxText(orbax) format using a CPU VM.
-# 2. Takes the MaxText(orbax) checkpoint to run inference and fine-tuning on a TPU VM.
+# 2. Takes the MaxText(orbax) checkpoint to run inference, fine-tuning, and pre-training on a TPU VM.

-# The flow of this file is to take the MaxText(orbax) checkpoint to run inference and fine-tuning on a TPU VM.
+# The flow of this file is to take the MaxText(orbax) checkpoint to run inference, fine-tuning, and pre-training on a TPU VM.
# Please make sure you have run end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh before running commands from this file.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -24,8 +24,16 @@ export M_BASE_OUTPUT_DIRECTORY=${BASE_OUTPUT_PATH}${MODEL_VARIATION}
export M_DATASET_PATH=gs://maxtext-dataset
export M_ASYNC_CHECKPOINTING=false

-# Run decoding
-python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext/0/items run_name=decoding per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" autoregressive_decode_assert="That's great to hear! I love to learn new things" attention=dot_product
+# `SCANNED_CHECKPOINT` refers to the checkpoint that is used for both `train.py` and `decode.py`
+export SCANNED_CHECKPOINT=${M_BASE_OUTPUT_DIRECTORY}/scanned_ckpt/0/items

+# Run decoding with converted ckpt
+python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" autoregressive_decode_assert="That's great to hear! I love to learn new things" attention=dot_product

# Run fine-tuning
-python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext/0/items run_name=fine_tuning per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral
+python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=fine_tuning per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral checkpoint_period=5

+# Run pre-training without load_parameters_path
+python3 MaxText/train.py MaxText/configs/base.yml run_name=pre_training per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=5 max_target_length=1024 tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral

+# TODO(ranran): Run decoding with unscanned ckpt
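
Based on the unscanned checkpoint written in step 1, the pending unscanned decoding test might look roughly like the sketch below; the checkpoint path and the scan_layers=false flag are assumptions for illustration, not part of this commit:

# Hedged sketch only: decode from the unscanned ckpt produced by generate_param_only_checkpoint.py
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned_ckpt/checkpoints/0/items  # assumed layout from step 1
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" attention=dot_product scan_layers=false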
