LLaMA is a language model developed by Meta. The official implementation can be found here. EasyLM provides a JAX implementation of LLaMA, located at EasyLM/models/llama.
If you are using our OpenLLaMA, you can directly download the EasyLM checkpoints and skip this section. If you are using the official LLaMA weights from Meta, the first step is to convert the Hugging Face transformers LLaMA checkpoint to the EasyLM checkpoint format. To do so, use the following command:
python -m EasyLM.models.llama.convert_hf_to_easylm \
--hf_model='path/to/transformers/llama/checkpoint' \
--output_file='path/to/output/easylm/checkpoint' \
--streaming=True \
--llama.base_model='llama_7b'
This script converts the Hugging Face transformers checkpoint of the LLaMA weights to the streaming checkpoint format used by EasyLM. If you set --streaming to False, the script will output a standard flax checkpoint instead. For more information about the checkpoint format of EasyLM, see the checkpointing documentation.
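If the conversion needs to be driven from a Python script (for example, to convert several checkpoints in a row), the command above can be assembled programmatically. The following is a minimal sketch: the helper names build_convert_command and run_conversion are hypothetical and not part of EasyLM, and the only flags used are the ones shown in the command above.

```python
import shlex
import subprocess
import sys


def build_convert_command(hf_model, output_file, streaming=True,
                          base_model="llama_7b"):
    """Assemble the conversion command as an argument list.

    Hypothetical helper: it only mirrors the flags shown in this section.
    """
    return [
        sys.executable, "-m", "EasyLM.models.llama.convert_hf_to_easylm",
        f"--hf_model={hf_model}",
        f"--output_file={output_file}",
        f"--streaming={streaming}",  # False requests a standard flax checkpoint
        f"--llama.base_model={base_model}",
    ]


def run_conversion(hf_model, output_file, **kwargs):
    """Run the conversion; raises CalledProcessError if the script fails."""
    cmd = build_convert_command(hf_model, output_file, **kwargs)
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Print the command instead of running it, so the sketch is safe to try.
    cmd = build_convert_command(
        "path/to/transformers/llama/checkpoint",
        "path/to/output/easylm/checkpoint",
    )
    print(shlex.join(cmd))
```

Passing streaming=False to the helper produces the --streaming=False variant described above.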
After converting the checkpoint and setting up the data, you can fine-tune LLaMA with EasyLM. The training script is implemented in EasyLM/models/llama/llama_train.py. To fine-tune LLaMA, use the following command:
python -m EasyLM.models.llama.llama_train \
--mesh_dim='1,-1,1' \
--llama.base_model='llama_7b' \
--load_checkpoint='params::path/to/easylm/llama/checkpoint' \
...
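The --mesh_dim='1,-1,1' value in the command above gives one size per mesh axis. The sketch below illustrates how such a string could be resolved into concrete axis sizes. It assumes that -1 means "infer this axis from the number of available devices", in the style of numpy.reshape; that is an assumption about the training script's behavior, not something stated on this page. The function name resolve_mesh_dims and the axis labels dp, fsdp, and mp are also hypothetical and chosen only for illustration.

```python
import math

# Assumed labels for the three axes: data, fully sharded data, and
# model parallelism. These names are illustrative only.
AXIS_NAMES = ("dp", "fsdp", "mp")


def resolve_mesh_dims(mesh_dim, device_count):
    """Resolve a string such as '1,-1,1' into a size for each mesh axis."""
    dims = [int(x) for x in mesh_dim.split(",")]
    if len(dims) != len(AXIS_NAMES):
        raise ValueError("mesh_dim must contain exactly 3 values")
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")

    # Product of the explicitly given axis sizes.
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        # Assumption: -1 takes whatever device count is left over.
        if known <= 0 or device_count % known != 0:
            raise ValueError("device count is not divisible by the fixed axes")
        dims[dims.index(-1)] = device_count // known
    elif known != device_count:
        raise ValueError("mesh size does not match the device count")
    return dict(zip(AXIS_NAMES, dims))


if __name__ == "__main__":
    # With 8 devices, '1,-1,1' puts all devices on the middle axis.
    print(resolve_mesh_dims("1,-1,1", device_count=8))
    # prints {'dp': 1, 'fsdp': 8, 'mp': 1}
```

Under these assumptions, '1,-1,1' places every device on the fully sharded data axis, whatever the size of the machine.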
The following command line options are supported for the training script:
- seed: The random seed to use for the training script.
- mesh_dim: The mesh dimensions for the data, fully sharded data, and model parallelism. LLaMA uses a 3D mesh, so a comma-separated list of 3 values is required. See the parallelism documentation for more details.
- dtype: The float dtype to use for the model activations. Can be bf16, fp16, or fp32.
- params_dtype: The float dtype to use for the model parameters. Can be bf16, fp16, or fp32.
- total_steps: The total number of training steps.
- load_checkpoint: The checkpoint to load. See the checkpointing documentation for more details.
- load_dataset_state: The dataset state to load. Rarely used.
- log_freq: The frequency of logging the training metrics.
- save_model_freq: The frequency of saving the model checkpoint. Older checkpoints are overwritten by the newest checkpoint.
- save_milestone_freq: The frequency of saving milestone checkpoints. Milestone checkpoints are not overwritten.
- eval_steps: The number of evaluation steps to run when evaluating the model. Setting this to 0 disables evaluation. Using this option requires eval_dataset to be properly specified.
- tokenizer: Hugging Face transformers pretrained tokenizer.
- train_dataset: Training dataset configuration. See the dataset documentation for more details.
- eval_dataset: Evaluation dataset configuration. See the dataset documentation for more details.
- optimizer: Optimizer configuration. See the optimizer documentation for more details.
- checkpointer: Checkpointer configuration. See the checkpointing documentation for more details.
- llama: Specify the LLaMA configuration by starting from a base model. The available configurati