diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5f08854842..6b90d1b501 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,12 @@ jobs: - cuda: "118" cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" + - cuda: "121" + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9514208b1c..2f0b074501 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,38 +27,56 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 + axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Docker metadata id: metadata - uses: docker/metadata-action@v3 + uses: docker/metadata-action@v5 with: images: winglian/axolotl + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - name: Build - uses: docker/build-push-action@v4 + # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/ + - name: Build and export to Docker + uses: docker/build-push-action@v5 with: context: . + load: true build-args: | BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} CUDA=${{ matrix.cuda }} PYTORCH_VERSION=${{ matrix.pytorch }} file: ./docker/Dockerfile - push: ${{ github.event_name != 'pull_request' }} tags: | ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} labels: ${{ steps.metadata.outputs.labels }} + - name: Unit Tests + run: | + docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ + - name: Push to Docker Hub + if: github.event_name != 'pull_request' + run: | + docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} + if [ -n "$latest_tag" ]; then + docker push "$latest_tag" + fi + build-axolotl-runpod: needs: build-axolotl if: github.repository_owner == 'OpenAccess-AI-Collective' @@ -80,26 +98,31 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 + axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Docker metadata id: metadata - uses: docker/metadata-action@v3 + uses: docker/metadata-action@v5 with: images: winglian/axolotl-runpod - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - name: Build - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v5 with: context: . build-args: | diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml new file mode 100644 index 0000000000..ff30d68ea2 --- /dev/null +++ b/.github/workflows/tests-docker.yml @@ -0,0 +1,62 @@ +name: e2e-docker-tests + +on: + pull_request: + paths: + - '**.py' + - 'requirements.txt' + workflow_dispatch: + +jobs: + build-axolotl: + if: github.repository_owner == 'OpenAccess-AI-Collective' + # this job needs to be run on self-hosted GPU runners... + strategy: + fail-fast: false + matrix: + include: + - cuda: 118 + cuda_version: 11.8.0 + python_version: "3.10" + pytorch: 2.0.1 + axolotl_extras: + is_latest: true + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 + axolotl_extras: + runs-on: [self-hosted, gpu, docker] + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Docker metadata + id: metadata + uses: docker/metadata-action@v5 + with: + images: winglian/axolotl + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/ + - name: Build and export to Docker + uses: docker/build-push-action@v5 + with: + context: . + load: true + build-args: | + BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} + CUDA=${{ matrix.cuda }} + PYTORCH_VERSION=${{ matrix.pytorch }} + file: ./docker/Dockerfile + tags: | + ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} + labels: ${{ steps.metadata.outputs.labels }} + - name: Unit Tests + run: | + docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ diff --git a/README.md b/README.md index c03eec54b4..172dd558e2 100644 --- a/README.md +++ b/README.md @@ -520,6 +520,14 @@ model_config: type: # linear | dynamic factor: # float +# optional overrides to the bnb 4bit quantization configuration +# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig +bnb_config_kwargs: + # These are default values + llm_int8_has_fp16_weight: false + bnb_4bit_quant_type: nf4 + bnb_4bit_use_double_quant: true + # Whether you are training a 4-bit GPTQ quantized model gptq: true @@ -581,6 +589,9 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: +# Saves the desired chat template to the tokenizer_config.json for easier inferencing +# Currently supports chatml and inst (mistral/mixtral) +chat_template: chatml # Axolotl attempts to save the dataset as an arrow after packing the data together so # subsequent training attempts load faster, relative path dataset_prepared_path: data/last_run_prepared @@ -632,7 +643,8 @@ max_memory: # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model adapter: lora # If you already have a lora model trained that you want to load, put that here. -# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. +# This means after training, if you want to test the model, you should set this to the value of `output_dir`. +# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`. lora_model_dir: # LoRA hyperparameters @@ -659,10 +671,6 @@ lora_modules_to_save: # - embed_tokens # - lm_head -# Once you complete training, the model will be saved to the following directory. -# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory. -# Make sure `lora_model_dir` points to this directory if you want to use the trained model. -lora_out_dir: lora_fan_in_fan_out: false # ReLoRA configuration @@ -672,6 +680,7 @@ relora_warmup_steps: # Number of per-restart warmup steps relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings # wandb configuration if you're using it +# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_project: # Your wandb project name wandb_entity: # A wandb Team name if using a Team @@ -729,6 +738,9 @@ group_by_length: false # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing gradient_checkpointing: false +# additional kwargs to pass to the trainer for gradient checkpointing +# gradient_checkpointing_kwargs: +# use_reentrant: false # Stop training after this many evaluation losses have increased in a row # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback @@ -798,11 +810,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: -# Landmark attention (only llama) -landmark_attention: -# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py -# LLaMA only -xpos_rope: # Resume from a specific checkpoint dir resume_from_checkpoint: @@ -925,8 +932,9 @@ accelerate launch -m axolotl.cli.train your_config.yml You can optionally pre-tokenize dataset with the following before finetuning. This is recommended for large datasets. -- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. -- Use `--debug` to see preprocessed examples. +- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset. +- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. +- (Optional): Use `--debug` to see preprocessed examples. ```bash python -m axolotl.cli.preprocess your_config.yml @@ -969,6 +977,8 @@ fsdp_config: ##### Weights & Biases Logging +Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. + - wandb options ```yaml wandb_mode: @@ -981,7 +991,7 @@ wandb_log_model: ##### Special Tokens -It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocubulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: +It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: ```yml special_tokens: @@ -995,9 +1005,12 @@ tokens: # these are delimiters When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary. -### Inference +### Inference Playground + +Axolotl allows you to load your model in an interactive terminal playground for quick experimentation. +The config file is the same config file used for training. -Pass the appropriate flag to the train command: +Pass the appropriate flag to the inference command, depending upon what kind of model was trained: - Pretrained LORA: ```bash @@ -1026,7 +1039,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi Add below flag to train command above ```bash -python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False +python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" ``` If you run out of CUDA memory, you can try to merge in system RAM with diff --git a/docker/Dockerfile b/docker/Dockerfile index 6eea7322ce..f8e0528562 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1" ENV PYTORCH_VERSION=$PYTORCH_VERSION RUN apt-get update && \ - apt-get install -y vim curl + apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev WORKDIR /workspace @@ -19,13 +19,15 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets -RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ else \ pip install -e .[deepspeed,flash-attn]; \ fi +# So we can test the Docker image +RUN pip install pytest + # fix so that git fetch/pull from remote works RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ git config --get remote.origin.fetch diff --git a/docs/rlhf.md b/docs/rlhf.md new file mode 100644 index 0000000000..371a40dbf7 --- /dev/null +++ b/docs/rlhf.md @@ -0,0 +1,35 @@ +# RLHF (Beta) + +### Overview + +Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human +feedback. Various methods include, but not limited to: + +- Proximal Policy Optimization (PPO) (not yet supported in axolotl) +- Direct Preference Optimization (DPO) +- Identity Preference Optimization (IPO) + + +### RLHF using Axolotl + +[!IMPORTANT] +This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality. + +The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML + +#### DPO +```yaml +rl: true +datasets: + - path: Intel/orca_dpo_pairs + split: train + type: intel_apply_chatml + - path: argilla/ultrafeedback-binarized-preferences + split: train + type: argilla_apply_chatml +``` + +#### IPO +```yaml +rl: ipo +``` diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index 1c37b05c13..ea62e9ebfe 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -17,6 +17,7 @@ output_dir: ./out sequence_len: 8192 sample_packing: true pad_to_sequence_len: true +eval_sample_packing: false wandb_project: wandb_entity: diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 15df88f967..11c842d4ee 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -23,6 +23,9 @@ unfrozen_parameters: # - model.layers.3[0-9]+.block_sparse_moe.gate.* # - model.layers.3[0-9]+.block_sparse_moe.experts.* +model_config: + output_router_logits: true + adapter: qlora lora_model_dir: diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 64b26f4fa3..35c79ebf4e 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -11,7 +11,7 @@ datasets: - path: mhenrichsen/alpaca_2k_test type: alpaca dataset_prepared_path: last_run_prepared -val_set_size: 0.05 +val_set_size: 0.1 output_dir: ./qlora-out adapter: qlora diff --git a/examples/tiny-llama/README.md b/examples/tiny-llama/README.md new file mode 100644 index 0000000000..467c06ec87 --- /dev/null +++ b/examples/tiny-llama/README.md @@ -0,0 +1,17 @@ +# Overview + +This is a simple example of how to finetune TinyLlama1.1B using either lora or qlora: + +LoRa: + +``` +accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml +``` + +qLoRa: + +``` +accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml +``` + +Both take about 10 minutes to complete on a 4090. diff --git a/examples/llama-2/tiny-llama.yml b/examples/tiny-llama/lora.yml similarity index 87% rename from examples/llama-2/tiny-llama.yml rename to examples/tiny-llama/lora.yml index c72db4e5b2..53d50178a8 100644 --- a/examples/llama-2/tiny-llama.yml +++ b/examples/tiny-llama/lora.yml @@ -1,5 +1,4 @@ -base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T - +base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer is_llama_derived_model: true @@ -17,6 +16,7 @@ output_dir: ./lora-out sequence_len: 4096 sample_packing: true +pad_to_sequence_len: true adapter: lora lora_model_dir: @@ -55,7 +55,6 @@ flash_attention: true warmup_steps: 10 evals_per_epoch: 4 -eval_table_size: saves_per_epoch: 1 debug: deepspeed: @@ -63,6 +62,3 @@ weight_decay: 0.0 fsdp: fsdp_config: special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml new file mode 100644 index 0000000000..dfd1bfca29 --- /dev/null +++ b/examples/tiny-llama/pretrain.yml @@ -0,0 +1,58 @@ +base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +max_steps: 200 +pretraining_dataset: + path: c4 + name: en +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./model-out + +sequence_len: 2048 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/examples/tiny-llama/qlora.yml b/examples/tiny-llama/qlora.yml new file mode 100644 index 0000000000..53791985ef --- /dev/null +++ b/examples/tiny-llama/qlora.yml @@ -0,0 +1,66 @@ +base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/examples/yi-34B-chat/README.md b/examples/yi-34B-chat/README.md new file mode 100644 index 0000000000..07078850fb --- /dev/null +++ b/examples/yi-34B-chat/README.md @@ -0,0 +1,5 @@ +# Overview + +This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM. + +Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory. diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml new file mode 100644 index 0000000000..0c1a4b7889 --- /dev/null +++ b/examples/yi-34B-chat/qlora.yml @@ -0,0 +1,76 @@ +base_model: 01-ai/Yi-34B-Chat +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_mistral_derived_model: false +is_llama_derived_model: true +load_in_8bit: false +load_in_4bit: true +strict: false +sequence_len: 1024 +bf16: true +fp16: false +tf32: false +flash_attention: true +special_tokens: + bos_token: "<|startoftext|>" + eos_token: "<|endoftext|>" + unk_token: "" + +# Data +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +warmup_steps: 10 + +# Iterations +num_epochs: 1 + +# Evaluation +val_set_size: 0.1 +evals_per_epoch: 5 +eval_table_size: +eval_table_max_new_tokens: 128 +eval_sample_packing: false +eval_batch_size: 1 + +# LoRA +output_dir: ./qlora-out +adapter: qlora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_target_modules: + +# Sampling +sample_packing: false +pad_to_sequence_len: false + +# Batching +gradient_accumulation_steps: 4 +micro_batch_size: 1 +gradient_checkpointing: true + +# wandb +wandb_project: + +# Optimizer +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +# Misc +train_on_inputs: false +group_by_length: false +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +debug: +deepspeed: +weight_decay: 0 +fsdp: +fsdp_config: diff --git a/requirements.txt b/requirements.txt index bbee7cf45b..14f6633f7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ auto-gptq==0.5.1 packaging peft==0.6.0 -transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d +transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 tokenizers==0.15.0 bitsandbytes>=0.41.1 accelerate==0.24.1 @@ -37,3 +37,5 @@ tensorboard s3fs gcsfs # adlfs + +trl @ git+https://github.com/huggingface/trl.git@main diff --git a/setup.py b/setup.py index 42fd22df11..fe4d2cfad8 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ """setup.py for axolotl""" +from importlib.metadata import PackageNotFoundError, version + from setuptools import find_packages, setup @@ -22,12 +24,13 @@ def parse_requirements(): # Handle standard packages _install_requires.append(line) - # TODO(wing) remove once xformers release supports torch 2.1.0 - if "torch==2.1.0" in _install_requires: - _install_requires.pop(_install_requires.index("xformers>=0.0.22")) - _install_requires.append( - "xformers @ git+https://github.com/facebookresearch/xformers.git@main" - ) + try: + torch_version = version("torch") + if torch_version.startswith("2.1.1"): + _install_requires.pop(_install_requires.index("xformers==0.0.22")) + _install_requires.append("xformers==0.0.23") + except PackageNotFoundError: + pass return _install_requires, _dependency_links diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 8ca4f7fe55..0477ebebfb 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -2,6 +2,7 @@ import importlib import logging +import math import os import random import sys @@ -16,6 +17,7 @@ # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art +from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer @@ -103,15 +105,7 @@ def do_inference( importlib.import_module("axolotl.prompters"), prompter ) - if cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id - - set_model_mem_id(model, tokenizer) - model.set_mem_cache_args( - max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None - ) - - model = model.to(cfg.device) + model = model.to(cfg.device, dtype=cfg.torch_dtype) while True: print("=" * 80) @@ -176,15 +170,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) - if cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id - - set_model_mem_id(model, tokenizer) - model.set_mem_cache_args( - max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None - ) - - model = model.to(cfg.device) + model = model.to(cfg.device, dtype=cfg.torch_dtype) def generate(instruction): if not instruction: @@ -341,6 +327,94 @@ def load_datasets( ) +def load_rl_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint: disable=unused-argument +) -> TrainDatasetMeta: + train_datasets: List[Any] = [] + for i, ds_cfg in enumerate(cfg.datasets): + train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"])) + # eval_dataset = load_dataset( + # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"] + # ) + eval_dataset = None + + def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" + sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + return sample + + def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + for i, data_set in enumerate(train_datasets): + _type = cfg.datasets[i]["type"] + ds_type_fn = locals()[_type] + train_datasets[i] = data_set.map(ds_type_fn) + train_dataset = concatenate_datasets(train_datasets) + + # eval_dataset = eval_dataset.map(intel_apply_chatml) + + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + def check_accelerate_default_config(): if Path(config_args.default_yaml_config_file).exists(): LOG.warning( diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 0caee4c28b..4c810d5722 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -18,7 +18,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs): return_remaining_strings=True ) parsed_cli_args.merge_lora = True - parsed_cfg = load_cfg(config, merge_lora=True, **kwargs) + + parsed_cfg = load_cfg( + config, + merge_lora=True, + load_in_8bit=False, + load_in_4bit=False, + flash_attention=False, + **kwargs + ) do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 81307b6b92..2248784dff 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -12,6 +12,7 @@ check_user_token, load_cfg, load_datasets, + load_rl_datasets, print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs @@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if parsed_cfg.rl: + dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + else: + dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ccd9d37c0d..b75766d043 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -9,7 +9,7 @@ import sys from abc import abstractmethod from dataclasses import dataclass, field -from functools import partial +from functools import wraps from pathlib import Path from typing import Optional @@ -20,6 +20,7 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_utils import seed_worker +from trl import DPOTrainer from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils.callbacks import ( @@ -59,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "Use quadratic warmup for cosine scheduling."}, ) + pretraining: bool = field( + default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) sample_packing: bool = field( default=False, metadata={"help": "Use sample packing for efficient training."}, @@ -120,6 +127,7 @@ class AxolotlTrainer(Trainer): """ args = None # type: AxolotlTrainingArguments + tag_names = ["axolotl"] def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs): self.num_epochs = num_epochs @@ -155,7 +163,7 @@ def create_scheduler( return self.lr_scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing: + if self.args.sample_packing and not self.args.pretraining: return MultipackBatchSampler( RandomSampler(self.train_dataset), self.args.train_batch_size, @@ -191,7 +199,7 @@ def _get_eval_sampler( return super()._get_eval_sampler(eval_dataset) def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing: + if self.args.sample_packing and not self.args.pretraining: train_dataset = self.train_dataset train_dataset = train_dataset.remove_columns(["length"]) data_collator = self.data_collator @@ -290,12 +298,41 @@ def compute_loss(self, model, inputs, return_outputs=False): # return (loss, outputs) if return_outputs else loss return super().compute_loss(model, inputs, return_outputs=return_outputs) + def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = self._sanitize_kwargs_for_tagging( + tag_names=self.tag_names, kwargs=kwargs + ) + + return super().push_to_hub(*args, **kwargs) + class AxolotlMambaTrainer(AxolotlTrainer): """ Mamba specific trainer to handle loss calculation """ + tag_names = ["axolotl", "mamba"] + def compute_loss( self, model, @@ -322,6 +359,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer): Trainer subclass that uses the OneCycleLR scheduler """ + tag_names = ["axolotl", "onecycle"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lr_scheduler = None @@ -351,6 +390,8 @@ class ReLoRATrainer(AxolotlTrainer): Trainer subclass that uses the OneCycleLR scheduler """ + tag_names = ["axolotl", "relora"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lr_scheduler = None @@ -386,12 +427,21 @@ class TrainerBuilderBase(abc.ABC): _train_dataset = None _eval_dataset = None + _model_ref = None def __init__(self, cfg, model, tokenizer): self.cfg = cfg self.model = model self.tokenizer = tokenizer + @property + def model_ref(self): + return self._model_ref + + @model_ref.setter + def model_ref(self, model): + self._model_ref = model + @property def train_dataset(self): return self._train_dataset @@ -532,6 +582,14 @@ def build(self, total_num_steps): training_arguments_kwargs[ "gradient_checkpointing" ] = self.cfg.gradient_checkpointing + if self.cfg.gradient_checkpointing_kwargs: + training_arguments_kwargs[ + "gradient_checkpointing_kwargs" + ] = self.cfg.gradient_checkpointing_kwargs + else: + training_arguments_kwargs["gradient_checkpointing_kwargs"] = { + "use_reentrant": False + } if self.cfg.fsdp: training_arguments_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp_config: @@ -692,6 +750,9 @@ def build(self, total_num_steps): and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine" ) + training_arguments_kwargs["lr_scheduler_kwargs"] = ( + self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} + ) training_arguments_kwargs["weight_decay"] = ( self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 ) @@ -712,6 +773,7 @@ def build(self, total_num_steps): training_arguments_kwargs ) training_arguments_kwargs["model_type"] = self.cfg.model_config_type + training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[ @@ -743,26 +805,6 @@ def build(self, total_num_steps): # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html data_collator_kwargs["pad_to_multiple_of"] = 64 - if self.cfg.is_llama_derived_model and self.cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( - add_mem_tokens, - get_mem_id, - set_model_mem_id, - ) - - set_model_mem_id(self.model, self.tokenizer) - - LOG.info("Adding landmark attention tokens to dataset") - - for dataset in [self.train_dataset, self.eval_dataset]: - dataset = dataset.map( - partial( - add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer) - ), - batched=False, - num_proc=32, - ) - trainer_cls = self._get_trainer_cls() trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls @@ -772,7 +814,7 @@ def build(self, total_num_steps): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, - data_collator=self.build_collator(**data_collator_kwargs), + # data_collator=self.build_collator(**data_collator_kwargs), bench_data_collator=transformers.DataCollatorForSeq2Seq( self.tokenizer, return_tensors="pt", @@ -802,3 +844,96 @@ def build_collator(self, **kwargs): return_tensors="pt", **kwargs, ) + + +class HFDPOTrainerBuilder(TrainerBuilderBase): + """ + Trainer factory class for DPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build_training_arguments(self, total_num_steps): + training_args_kwargs = {} + for arg in [ + "adam_beta1", + "adam_beta2", + "adam_epsilon", + "dataloader_num_workers", + "dataloader_pin_memory", + ]: + if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: + training_args_kwargs[arg] = getattr(self.cfg, arg) + training_args = TrainingArguments( + per_device_train_batch_size=self.cfg.micro_batch_size, + max_steps=total_num_steps, + remove_unused_columns=False, + gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, + learning_rate=self.cfg.learning_rate, + evaluation_strategy="no", + # eval_steps=self.cfg.eval_steps, + save_strategy="steps", + save_steps=self.cfg.save_steps, + output_dir=self.cfg.output_dir, + warmup_steps=self.cfg.warmup_steps, + bf16=True, + gradient_checkpointing=self.cfg.gradient_checkpointing, + gradient_checkpointing_kwargs={"use_reentrant": False}, + logging_first_step=True, + logging_steps=1, + optim=self.cfg.optimizer, + save_total_limit=self.cfg.save_total_limit or 5, + **training_args_kwargs, + ) + + return training_args + + def build(self, total_num_steps): + training_args = self.build_training_arguments(total_num_steps) + dpo_trainer_kwargs = {} + if self.cfg.rl == "ipo": + dpo_trainer_kwargs["loss_type"] = "ipo" + if self.cfg.dpo_label_smoothing: + dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + + dpo_trainer = DPOTrainer( + self.model, + self.model_ref, + args=training_args, + beta=self.cfg.dpo_beta or 0.1, + train_dataset=self.train_dataset, + # eval_dataset=self.eval_dataset, + eval_dataset=None, + tokenizer=self.tokenizer, + max_length=self.cfg.sequence_len, + max_target_length=None, + max_prompt_length=self.cfg.sequence_len, + generate_during_eval=True, + **dpo_trainer_kwargs, + ) + + return dpo_trainer + + +class HFPPOTrainerBuilder(TrainerBuilderBase): + """ + HF Factory class for PPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build(self, total_num_steps): + # build PPOConfig + pass diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py new file mode 100644 index 0000000000..24c0b04123 --- /dev/null +++ b/src/axolotl/core/trainers/trl.py @@ -0,0 +1,66 @@ +""" +module for TRL PPO training +""" +import torch +from tqdm import tqdm +from trl import PPOTrainer + + +class TRLPPOTrainer(PPOTrainer): + """ + wrapper for ppo trainer to handle customizations + """ + + def train( + self, + reward_pipe, + resume_from_checkpoint=None, # pylint: disable=unused-argument + ): + generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.tokenizer.eos_token_id, + "max_new_tokens": 32, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + } + + for epoch, batch in tqdm( # pylint: disable=unused-variable + enumerate(self.dataloader) + ): + query_tensors = batch["input_ids"] + + # generate model response + response_tensors, ref_response_tensors = self.generate( + query_tensors, + return_prompt=False, + generate_ref_response=True, + **generation_kwargs + ) + batch["response"] = self.tokenizer.batch_decode(response_tensors) + batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] + ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) + ref_rewards = [ + torch.tensor(output[1]["score"]) for output in ref_pipe_outputs + ] + batch["ref_rewards"] = ref_rewards + + # Run PPO step + stats = self.step(query_tensors, response_tensors, rewards) + self.log_stats( + stats, + batch, + rewards, + columns_to_log=["query", "response", "ref_response", "ref_rewards"], + ) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index 19313fb7e2..aafdabe547 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -82,15 +82,44 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.LLAMA2: - seps = [self.sep, self.sep2] + if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": if self.system_message: + if self.messages: + # For llama, the system message is incorporated into the first human instruction + first_role, first_msg = self.messages[0] + if first_role == self.roles[0]: + system_prompt += first_msg + self.messages.pop(0) yield "", system_prompt - else: - yield "", "[INST] " - for i, (role, message) in enumerate(self.messages[1:]): + for i, (role, message) in enumerate(self.messages): if message: - yield role + " ", message + seps[i % 2] + if (i % 2 == 0 and not self.system_message) or ( + i % 2 != 0 and self.system_message + ): + role = " " + role + yield role + " ", message + else: + yield role, "" + return + if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": + contains_sys_msg = False + if self.system_message: + contains_sys_msg = True + if self.messages: + # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline + first_role, first_msg = self.messages[0] + if first_role == self.roles[0]: + system_prompt = self.system_template.format( + system_message=" " + self.system_message + ) + system_prompt += first_msg + self.messages.pop(0) + yield "", system_prompt + for i, (role, message) in enumerate(self.messages): + if message and i == 0 and not contains_sys_msg: + yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a ` [INST]` at the beginning of the first instruction. + elif message: + yield role + " ", message else: yield role, "" return @@ -118,6 +147,15 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + "\n", "" return + if self.sep_style == SeparatorStyle.CHATGLM3: + if self.system_message: + yield "", system_prompt + for role, message in self.messages: + if message: + yield role + "\n", " " + message + else: + yield role + return if self.sep_style == SeparatorStyle.CHATINTERN: # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 seps = [self.sep, self.sep2] diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py deleted file mode 100644 index 24a98305f3..0000000000 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ /dev/null @@ -1,1249 +0,0 @@ -# pylint: skip-file -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -PyTorch LLaMA model. -Taken from https://github.com/epfml/landmark-attention/blob/main/llama/llama_mem.py and modified. -""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers import LlamaTokenizer -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import ( - LLAMA_INPUTS_DOCSTRING, - LLAMA_START_DOCSTRING, - LlamaMLP, - LlamaPreTrainedModel, - LlamaRMSNorm, - LlamaRotaryEmbedding, - _expand_mask, - _make_causal_mask, - rotate_half, -) -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) - -LOG = logging.getLogger("axolotl") - -_CONFIG_FOR_DOC = "LlamaConfig" - -MEM_TOKEN = "" # nosec - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - if q is None: - q_embed = None - else: - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LandmarkGroupedSoftmaxFunction(torch.autograd.Function): - """ - Landmark grouped softmax function. - """ - - # Note that forward, setup_context, and backward are @staticmethods - @staticmethod - def forward(ctx, x, dim, mem_cnt, resp_mem_idx): - new_shape = list(x.shape) - new_shape[dim] = mem_cnt # max_mem_cnt.item() - max_by_group = x.new_zeros((*new_shape,)) - max_by_group.scatter_reduce_( - src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False - ) - - maxes = torch.gather(max_by_group, dim, resp_mem_idx) - # x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes)) - x_exp = torch.exp((x - maxes).to(torch.float32)) - - cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype) - - cumsum_by_group.scatter_add_( - dim, - resp_mem_idx, - x_exp, - ) - denom = torch.gather(cumsum_by_group, dim, resp_mem_idx) - - # probs = torch.where(denom < 0.5, 0, x_exp / denom) - probs = x_exp / denom - - ctx.mem_cnt = mem_cnt - ctx.dim = dim - ctx.save_for_backward(resp_mem_idx, probs) - - return probs - - @staticmethod - def backward(ctx, grad_probs): - mem_cnt = ctx.mem_cnt - dim = ctx.dim - resp_mem_idx, probs = ctx.saved_tensors - grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[4]: - grad_pair = grad_probs * probs - - new_shape = list(probs.shape) - new_shape[dim] = mem_cnt # max_mem_cnt.item() - cumsum_by_group = grad_pair.new_zeros((*new_shape,)) - cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair) - - if ctx.needs_input_grad[0]: - grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx) - grad_x = grad_pair - probs * grad_sum - assert not ctx.needs_input_grad[1] - assert not ctx.needs_input_grad[2] - assert not ctx.needs_input_grad[3] - - return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx - - -def landmark_grouped_softmax(x, dim, is_mem, last_section_mask): - last_and_rest_mask = last_section_mask # | mask - - full_access_mask = is_mem | last_and_rest_mask - - max_mem_cnt = 16 - mem_group_idx = torch.cumsum(is_mem, dim=dim) - mem_bucket_id = max_mem_cnt - 1 - resp_mem_idx = torch.where( - last_and_rest_mask, - max_mem_cnt - 1, - torch.where(is_mem, mem_bucket_id, mem_group_idx), - ) - probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx) - - new_shape = list(x.shape) - new_shape[dim] = max_mem_cnt - group_prob = probs.new_zeros((*new_shape,)) - group_prob.scatter_( - dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs - ) - probs = probs.mul( - torch.where( - full_access_mask, - last_section_mask, - torch.gather(group_prob, dim, resp_mem_idx), - ) - ) - - return probs - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings - ) - - self.mem_freq = None - self.top_k = None - self.max_cache_size = None - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def set_mem_cache_args(self, mem_freq, top_k, max_cache_size): - self.mem_freq = mem_freq - self.top_k = top_k - self.max_cache_size = max_cache_size - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - is_mem: Optional[torch.Tensor] = None, - last_section_mask: Optional[torch.Tensor] = None, - offload_cache_to_cpu: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if len(past_key_value) > 2: - kv_seq_len += past_key_value[3].shape[2] * past_key_value[3].shape[3] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - key_states_before_pos = key_states - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - attn_prefix = None - if past_key_value is not None: - # reuse k, v, self_attention - if self.mem_freq is None: - cache_len = past_key_value[0].shape[2] - if self.max_cache_size is not None: - cache_len = min(cache_len, self.max_cache_size) - if is_mem is not None: - is_mem = torch.cat( - (is_mem.new_zeros((1, 1, q_len, cache_len)), is_mem), dim=-1 - ) - last_section_mask = torch.cat( - ( - last_section_mask.new_ones((1, 1, q_len, cache_len)), - last_section_mask, - ), - dim=-1, - ) - - past_key_states = torch.cat([past_key_value[0], key_states], dim=2) - past_value_states = torch.cat([past_key_value[1], value_states], dim=2) - key_states = past_key_states[:, :, -(q_len + cache_len) :] - value_states = past_value_states[:, :, -(q_len + cache_len) :] - expected_att_size = (bsz, self.num_heads, q_len, cache_len + q_len) - else: - orig_value_states = value_states - - incomplete_len = past_key_value[0].shape[2] % (self.mem_freq + 1) - full_len = past_key_value[0].shape[2] - incomplete_len - past_key_mem, past_key_incomplete = torch.split( - past_key_value[0], (full_len, incomplete_len), dim=2 - ) - past_value_mem, past_value_incomplete = torch.split( - past_key_value[1], (full_len, incomplete_len), dim=2 - ) - - if offload_cache_to_cpu: - past_key_value = ( - past_key_incomplete, - past_value_incomplete, - *past_key_value[2:], - ) - - if incomplete_len > 0: - assert q_len + incomplete_len <= (self.mem_freq + 1) - is_mem = torch.cat( - (is_mem.new_zeros((1, 1, q_len, incomplete_len)), is_mem), dim=-1 - ) - last_section_mask = torch.cat( - ( - last_section_mask.new_ones((1, 1, q_len, incomplete_len)), - last_section_mask, - ), - dim=-1, - ) - - if len(past_key_value) > 2: - full_len += past_key_value[3].shape[2] * past_key_value[3].shape[3] - past_key_incomplete_pos = torch.arange( - full_len, - full_len + incomplete_len, - dtype=torch.long, - device=position_ids.device, - ).unsqueeze(0) - _, past_key_incomplete = apply_rotary_pos_emb( - None, past_key_incomplete, cos, sin, past_key_incomplete_pos - ) - key_states = torch.cat((past_key_incomplete, key_states), dim=2) - value_states = torch.cat((past_value_incomplete, value_states), dim=2) - - past_key_mem = past_key_mem.view( - bsz, self.num_heads, -1, self.mem_freq + 1, self.head_dim - ) - past_value_mem = past_value_mem.view( - bsz, self.num_heads, -1, self.mem_freq + 1, self.head_dim - ) - - if len(past_key_value) > 2: - mem_key_nopos = torch.cat( - ( - past_key_value[2], - past_key_mem.select(dim=3, index=self.mem_freq), - ), - dim=2, - ) - past_key_mem_offload = past_key_value[3] - past_key_mem = torch.cat( - ( - past_key_mem_offload, - past_key_mem.to(past_key_mem_offload.device), - ), - dim=2, - ) - past_value_mem = torch.cat( - ( - past_key_value[4], - past_value_mem.to(past_key_mem_offload.device), - ), - dim=2, - ) - else: - mem_key_nopos = past_key_mem.select(dim=3, index=self.mem_freq) - - num_mems = past_key_mem.shape[2] - top_k = min(self.top_k, num_mems) - prefix_len = full_len - (top_k + 1) * (self.mem_freq + 1) - mem_indices = torch.cat( - ( - position_ids.new_zeros((max(0, num_mems - top_k),)), - torch.arange( - 1, - top_k + 1, - device=query_states.device, - dtype=position_ids.dtype, - ), - ), - dim=0, - ) - mem_pos = (mem_indices * (self.mem_freq + 1) + self.mem_freq).unsqueeze( - 0 - ).expand(bsz, -1) + prefix_len - _, mem_key = apply_rotary_pos_emb( - None, mem_key_nopos, cos, sin, mem_pos - ) - mem_attn_weights = torch.matmul( - query_states, mem_key.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if offload_cache_to_cpu: - aggregate = "max_over_tokens" - else: - aggregate = None - if aggregate == "max_over_tokens": - token_retrievers = 1 - head_retrievers = self.num_heads - mem_attn_weights = torch.nn.functional.softmax( - mem_attn_weights, dim=-1 - ) - mem_attn_weights = mem_attn_weights.amax(dim=2, keepdim=True) - elif aggregate is None: - token_retrievers = q_len - head_retrievers = self.num_heads - else: - raise NotImplementedError() - - mem_selected_idx = ( - mem_attn_weights.topk(dim=-1, k=top_k)[1] - .sort(dim=-1)[0] - .view(bsz, head_retrievers, token_retrievers, top_k) - ) - - selected_indices = torch.arange( - 0, - top_k * (self.mem_freq + 1), - device=query_states.device, - dtype=position_ids.dtype, - ) - selected_indices = torch.where( - mem_selected_idx >= num_mems - top_k, self.mem_freq + 1, 0 - ).unsqueeze(-1) + selected_indices.view( - 1, 1, 1, top_k, self.mem_freq + 1 - ) - selected_indices = ( - selected_indices.view( - bsz, head_retrievers, token_retrievers, -1 - ).expand(bsz, self.num_heads, q_len, -1) - + prefix_len - ) - - mem_selected_idx = mem_selected_idx.to(past_key_mem.device) - - mem_selected_idx = mem_selected_idx.view( - bsz, self.num_heads, token_retrievers, top_k, 1, 1 - ).expand( - bsz, - self.num_heads, - token_retrievers, - top_k, - self.mem_freq + 1, - self.head_dim, - ) - selected_keys = past_key_mem.unsqueeze(2).expand( - bsz, - self.num_heads, - token_retrievers, - -1, - self.mem_freq + 1, - self.head_dim, - ) - selected_keys = selected_keys.take_along_dim( - mem_selected_idx, dim=3 - ).to(query_states.device) - selected_values = ( - past_value_mem.unsqueeze(2) - .expand( - bsz, - self.num_heads, - token_retrievers, - -1, - self.mem_freq + 1, - self.head_dim, - ) - .take_along_dim(mem_selected_idx, dim=3) - .to(query_states.device) - ) - - selected_keys = selected_keys.view( - bsz, self.num_heads, token_retrievers, -1, self.head_dim - ).expand(bsz, self.num_heads, q_len, -1, self.head_dim) - selected_keys = apply_rotary_pos_emb( - None, selected_keys.unsqueeze(1), cos, sin, selected_indices - )[1].squeeze(1) - selected_values = selected_values.view( - bsz, self.num_heads, token_retrievers, -1, self.head_dim - ).expand(bsz, self.num_heads, q_len, -1, self.head_dim) - attn_prefix = torch.matmul( - query_states.unsqueeze(3), selected_keys.transpose(3, 4) - ).squeeze(3) / math.sqrt(self.head_dim) - is_mem_prefix = ( - torch.cat( - (is_mem.new_zeros((self.mem_freq,)), is_mem.new_ones((1,))) - ) - .unsqueeze(0) - .repeat((top_k, 1)) - ) - is_mem_prefix = is_mem_prefix.view(1, 1, 1, -1).expand(1, 1, q_len, -1) - is_mem = torch.cat((is_mem_prefix, is_mem), dim=-1) - last_section_mask = torch.cat( - ( - last_section_mask.new_zeros( - (1, 1, q_len, top_k * (self.mem_freq + 1)) - ), - last_section_mask, - ), - dim=-1, - ) - expected_att_size = (bsz, self.num_heads, q_len, q_len + incomplete_len) - - past_key_states = torch.cat( - [past_key_value[0], key_states_before_pos], dim=2 - ) - past_value_states = torch.cat( - [past_key_value[1], orig_value_states], dim=2 - ) - - if offload_cache_to_cpu: - past_key_value = ( - ( - past_key_states, - past_value_states, - mem_key_nopos, - past_key_mem.to("cpu"), - past_value_mem.to("cpu"), - *past_key_value[5:], - ) - if use_cache - else None - ) - else: - past_key_value = ( - (past_key_states, past_value_states) if use_cache else None - ) - - else: - if self.mem_freq is None: - past_key_states = key_states - else: - past_key_states = key_states_before_pos - past_value_states = value_states - expected_att_size = (bsz, self.num_heads, q_len, kv_seq_len) - past_key_value = (past_key_states, past_value_states) if use_cache else None - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - if attn_weights.size() != expected_att_size: - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask[..., -attn_weights.shape[-1] :] - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - if attn_prefix is not None: - attn_weights = torch.cat((attn_prefix, attn_weights), dim=-1) - # upcast attention to fp32 - if is_mem is None: - raise ValueError("Don't use this without landmarks") - - attn_weights = landmark_grouped_softmax( - attn_weights, - dim=-1, - is_mem=is_mem.expand(-1, self.num_heads, -1, -1), - last_section_mask=last_section_mask, - ).to(query_states.dtype) - - if attn_prefix is not None: - attn_prefix, attn_weights = torch.split( - attn_weights, - (attn_prefix.shape[-1], attn_weights.shape[-1] - attn_prefix.shape[-1]), - dim=-1, - ) - attn_output = torch.matmul(attn_weights, value_states) - if attn_prefix is not None: - attn_output += torch.matmul( - attn_prefix.unsqueeze(3), selected_values - ).squeeze(3) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - """ - Llama Decoder layer - """ - - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def set_mem_cache_args(self, mem_freq, top_k, max_cache_size): - self.self_attn.set_mem_cache_args(mem_freq, top_k, max_cache_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - is_mem: Optional[torch.Tensor] = None, - last_section_mask: Optional[torch.Tensor] = None, - offload_cache_to_cpu: bool = False, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - is_mem=is_mem, - last_section_mask=last_section_mask, - offload_cache_to_cpu=offload_cache_to_cpu, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.mem_id = None - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def set_mem_id(self, mem_id): - self.mem_id = mem_id - - def set_mem_cache_args(self, mem_freq, top_k, max_cache_size): - for layer in self.layers: - layer.set_mem_cache_args(mem_freq, top_k, max_cache_size) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - offload_cache_to_cpu: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - is_mem = None - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - if self.mem_id is not None: - with torch.no_grad(): - is_mem = input_ids == self.mem_id - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - if self.mem_id is not None: - raise NotImplementedError - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - if is_mem is not None: - pass - # raise NotImplementedError - past_key_values_length = past_key_values[0][0].shape[2] - if len(past_key_values[0]) > 2: - past_key_values_length += ( - past_key_values[0][3].shape[2] * past_key_values[0][3].shape[3] - ) - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - last_section_mask = None - if is_mem is not None: - is_mem = is_mem.unsqueeze(1).unsqueeze(2) - current_len = input_ids.shape[1] - mem_ids = torch.where( - attention_mask[..., -current_len:] < -1, - 0, - torch.cumsum(is_mem, -1) - is_mem.int(), - ) - last_section_mask = torch.amax(mem_ids, -1, keepdim=True) == mem_ids - attention_mask[..., -current_len:].masked_fill_( - last_section_mask & is_mem, - torch.tensor( - torch.finfo(inputs_embeds.dtype).min, device=inputs_embeds.device - ), - ) - last_section_mask.logical_and_(attention_mask[..., -current_len:] > -1) - is_mem = is_mem.logical_and(attention_mask[..., -current_len:] > -1) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - LOG.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - output_attentions, - None, - is_mem, - last_section_mask, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - is_mem=is_mem, - last_section_mask=last_section_mask, - offload_cache_to_cpu=offload_cache_to_cpu, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - """ - Llama model with a causal language modeling head. - """ - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.mem_id = None - self.mem_freq = None - self.top_k = None - self.max_seq_len = None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - offload_cache_to_cpu: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - window_len = self.max_seq_len or input_ids.shape[1] - last_logits = None - for _, idx in enumerate(range(0, input_ids.shape[1], window_len)): - if idx >= 1: - if output_attentions or output_hidden_states: - raise NotImplementedError - if not use_cache: - raise NotImplementedError - outputs = self.model( - input_ids=input_ids[:, idx : idx + window_len], - attention_mask=attention_mask[ - :, : idx + window_len + attention_mask.shape[1] - input_ids.shape[1] - ] - if attention_mask is not None - else None, - position_ids=position_ids[:, idx : idx + window_len] - if position_ids is not None - else None, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds[:, idx : idx + window_len] - if inputs_embeds is not None - else None, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - offload_cache_to_cpu=offload_cache_to_cpu, - ) - past_key_values = outputs.past_key_values - if last_logits is not None: - last_logits = torch.cat((last_logits, outputs[0]), dim=-2) - last_logits = outputs[0] - - hidden_states = last_logits - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def set_mem_id(self, mem_id): - self.mem_id = mem_id - self.model.set_mem_id(mem_id) - - def set_mem_cache_args(self, max_seq_len, mem_freq, top_k, max_cache_size): - self.mem_freq = mem_freq - self.top_k = top_k - self.max_seq_len = max_seq_len - if self.max_seq_len is not None: - assert self.max_seq_len % (self.mem_freq + 1) == 0 - self.model.set_mem_cache_args(mem_freq, top_k, max_cache_size) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - total_len = input_ids.shape[1] - if past_key_values: - prev_len = input_ids.shape[1] - 1 - else: - prev_len = 0 - - position_ids = kwargs.get("position_ids", None) - - if self.mem_freq is not None: - if position_ids is not None: - raise NotImplementedError - # T = input_ids.shape[1] - - prev_incomplete_len = prev_len % self.mem_freq - prev_complete_len = prev_len - prev_incomplete_len - incomplete_len = total_len % self.mem_freq - new_full_len = total_len - prev_complete_len - incomplete_len - - prev_input, input_ids_with_mem, input_ids_without_mem = torch.split( - input_ids, (prev_complete_len, new_full_len, incomplete_len), dim=-1 - ) - - bsz, _ = input_ids.size() - input_ids_with_mem = input_ids_with_mem.view(bsz, -1, self.mem_freq) - input_ids_with_mem = torch.cat( - ( - input_ids_with_mem, - input_ids_with_mem.new_full( - (bsz, input_ids_with_mem.shape[1], 1), self.mem_id - ), - ), - dim=-1, - ).view(bsz, -1) - input_ids = torch.cat( - (prev_input, input_ids_with_mem, input_ids_without_mem), dim=-1 - ) - if attention_mask is not None: - attention_mask_with_mem, attention_mask_without_mem = torch.split( - attention_mask, - (prev_complete_len + new_full_len, incomplete_len), - dim=-1, - ) - attention_mask_with_mem = attention_mask_with_mem.view( - bsz, -1, self.mem_freq - ) - attention_mask_with_mem = torch.cat( - ( - attention_mask_with_mem, - attention_mask_with_mem.new_ones( - (bsz, attention_mask_with_mem.shape[1], 1) - ), - ), - dim=-1, - ).view(bsz, -1) - attention_mask = torch.cat( - (attention_mask_with_mem, attention_mask_without_mem), dim=-1 - ) - - input_ids = input_ids[:, prev_len:] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[:, -input_ids.shape[1] :].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if ( - inputs_embeds is not None - and past_key_values is None - and self.mem_freq is None - ): - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "offload_cache_to_cpu": kwargs.get("offload_cache_to_cpu"), - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past - - -def add_mem_tokens(example, mem_freq, mem_id): - ids = example["input_ids"] - ret = [] - prev_idx = 0 - for t_idx in range(mem_freq, len(ids), mem_freq): - ret.extend(ids[prev_idx:t_idx]) - ret.append(mem_id) - prev_idx = t_idx - ret.extend(ids[prev_idx:]) - # drop attention_mask - return {"input_ids": ret} - - -def patch_llama_with_landmark_attn(): - import transformers - - transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM - transformers.models.llama.modeling_llama.LlamaModel = LlamaModel - transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer - transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb - - -def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer): - mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) - model.set_mem_id(mem_id) - - -def get_mem_id(tokenizer: LlamaTokenizer): - return tokenizer.convert_tokens_to_ids(MEM_TOKEN) diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py index 4188146892..74fa00f649 100644 --- a/src/axolotl/monkeypatch/mixtral/__init__.py +++ b/src/axolotl/monkeypatch/mixtral/__init__.py @@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn(): transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = ( mixtral_model_forward ) - transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[ + transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[ "flash_attention_2" ] = MixtralMultipackFlashAttention2 diff --git a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py index 34f35015f9..db892530d6 100644 --- a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +++ b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py @@ -261,7 +261,11 @@ def mixtral_model_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if ( + attention_mask is not None + and self._attn_implementation == "flash_attention_2" + and use_cache + ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -270,7 +274,7 @@ def mixtral_model_forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = ( attention_mask diff --git a/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py b/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py deleted file mode 100644 index 4cbbd4f479..0000000000 --- a/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py +++ /dev/null @@ -1,94 +0,0 @@ -# pylint: skip-file -""" -Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py -""" -import torch -import transformers -import transformers.models.llama.modeling_llama -from einops import rearrange - - -class XposRotaryEmbedding(torch.nn.Module): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scale_base=2048, - use_xpos=True, - ): - super().__init__() - self.max_seq_len_cached = max_position_embeddings - self.scale_base = scale_base - - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq) - freqs = torch.einsum("i , j -> i j", t, inv_freq) - freqs = torch.cat((freqs, freqs), dim=-1) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.register_buffer("freqs_cached", freqs, persistent=False) - - if not use_xpos: - self.register_buffer("scale", None) - self.register_buffer("scale_cached", torch.ones(1)) - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - power = (t - (self.max_seq_len_cached // 2)) / self.scale_base - scale_cached = scale ** rearrange(power, "n -> n 1") - scale_cached = torch.cat((scale_cached, scale_cached), dim=-1) - - self.register_buffer("scale", scale, persistent=False) - self.register_buffer("scale_cached", scale_cached, persistent=False) - - def forward( - self, - x, - seq_len, - ): - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device).type_as( - self.inv_freq - ) - freqs = torch.einsum("i , j -> i j", t, self.inv_freq) - freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype) - - self.register_buffer("freqs_cached", freqs) - - if self.scale is None: - self.register_buffer( - "scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype) - ) - - return self.freqs_cached.to(dtype=x.dtype), self.scale_cached - - power = (t - (seq_len // 2)) / self.scale_base - scale = self.scale ** rearrange(power, "n -> n 1") - scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype) - self.register_buffer("scale_cached", scale) - - return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype) - - -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None): - freqs = freqs[position_ids, :] - if scale.shape[-1] != 1: - scale = scale[position_ids, :] - - q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale) - k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale) - - return q_embed, k_embed - - -def replace_llama_rope_with_xpos_rope(): - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding - transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index fbb44ccfae..c026889682 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -39,6 +39,23 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): return strategy +def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None + ) + strategy = UltrachatShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation=conversation, + ), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "strict" in ds_cfg: + strategy.strict = ds_cfg["strict"] + return strategy + + def load_role(tokenizer, cfg): return SimpleRoleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2(), @@ -109,3 +126,17 @@ def get_conversation_thread(self, prompt): {"from": role_map[t["role"]], "value": t["text"]} for t in conversations ] return turns + + +class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): + """ + sharegpt strategy that remaps ultrachat data to sharegpt format + """ + + def get_conversation_thread(self, prompt): + conversations = prompt["messages"] + role_map = {"user": "human", "assistant": "gpt"} + turns = [ + {"from": role_map[t["role"]], "value": t["content"]} for t in conversations + ] + return turns diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 169fc51272..e0da112528 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ from accelerate.logging import get_logger from datasets import Dataset from optimum.bettertransformer import BetterTransformer +from pkg_resources import get_distribution # type: ignore from transformers.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs @@ -60,6 +61,12 @@ def train( msg += " and peft_config..." LOG.debug(msg) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model_ref = None + if cfg.rl: + # load the model again for model_ref/baseline + model_ref, _ = load_model( + cfg, tokenizer, inference=cli_args.inference, reference_model=True + ) safe_serialization = cfg.save_safetensors is True @@ -82,7 +89,7 @@ def train( freeze_parameters_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( - cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps + cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps ) if hasattr(model, "config"): @@ -115,6 +122,12 @@ def terminate_handler(_, __, model): badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + if getattr(cfg, "axolotl_config_path"): + raw_axolotl_cfg = Path(cfg.axolotl_config_path) + version = get_distribution("axolotl").version + if raw_axolotl_cfg.is_file(): + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" + LOG.info("Starting trainer...") if cfg.group_by_length: LOG.info("hang tight... sorting dataset for group_by_length") diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 8599c0df0f..122cd92ede 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -4,6 +4,8 @@ import logging import os +from shutil import copyfile +from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Dict, List import evaluate @@ -561,10 +563,15 @@ def on_train_begin( ): if is_main_process(): try: - artifact = wandb.Artifact(name="axolotl-config", type="config") - artifact.add_file(local_path=self.axolotl_config_path) - wandb.run.log_artifact(artifact) - LOG.info("Axolotl config has been saved to WandB as an artifact.") + # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + wandb.save(temp_file.name) + LOG.info( + "The Axolotl config has been saved to the WandB run under files." + ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py new file mode 100644 index 0000000000..459da44007 --- /dev/null +++ b/src/axolotl/utils/chat_templates.py @@ -0,0 +1,29 @@ +""" +This module provides functionality for selecting chat templates based on user choices. +These templates are used for formatting messages in a conversation. +""" + + +def chat_templates(user_choice: str): + """ + Finds the correct chat_template for the tokenizer_config. + + Args: + user_choice (str): The user's choice of template. + + Returns: + str: The chosen template string. + + Raises: + ValueError: If the user_choice is not found in the templates. + """ + + templates = { + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + } + + if user_choice in templates: + return templates[user_choice] + + raise ValueError(f"Template '{user_choice}' not found.") diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index 0f0eb5a95a..b9c1c3b3c1 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -178,3 +178,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: "input_ids": input_ids, "labels": labels, } + + +@dataclass +class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + """ + Collator for multipack specific to the using the BatchSampler + """ + + def __call__(self, features, return_tensors=None): + chunked_data = {} + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [(1) * np.array(item) for item in features[feature]] + chunked_data[feature] = np.concatenate(arrays) + else: + arrays = [np.array(item) for item in features[feature]] + chunked_data[feature] = np.concatenate(arrays) + features = [chunked_data] + return super().__call__(features, return_tensors=return_tensors) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 1b4ce92465..9bade45728 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -422,11 +422,6 @@ def validate_config(cfg): if cfg.warmup_steps and cfg.warmup_ratio: raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") - if cfg.is_qwen_derived_model and cfg.gradient_checkpointing: - LOG.warning( - "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch." - ) - if cfg.wandb_run_id and not cfg.wandb_name: cfg.wandb_name = cfg.wandb_run_id @@ -448,6 +443,20 @@ def validate_config(cfg): if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0: raise ValueError("neftune_noise_alpha must be > 0.0") + if ( + cfg.adapter + and cfg.tokens + and ( + not cfg.lora_modules_to_save + or not all( + x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"] + ) + ) + ): + raise ValueError( + "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`." + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 5c41d16fe4..40a3306021 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -2,6 +2,7 @@ import functools import hashlib import logging +from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple, Union @@ -14,6 +15,7 @@ load_from_disk, ) from huggingface_hub import hf_hub_download +from torch.utils.data import RandomSampler from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH @@ -39,11 +41,14 @@ SummarizeTLDRPrompter, UnsupportedPrompter, ) +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.samplers.multipack import MultipackBatchSampler from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, + process_pretraining_datasets_for_packing, ) LOG = logging.getLogger("axolotl") @@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer): tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) else: + path = cfg.pretraining_dataset + name = None + if isinstance(cfg.pretraining_dataset, dict): + path = cfg.pretraining_dataset["path"] + name = cfg.pretraining_dataset["name"] + train_dataset = load_pretraining_dataset( - cfg.pretraining_dataset, + path, tokenizer, + cfg, + name=name, max_tokens=cfg.sequence_len, seed=cfg.seed or 42, ) @@ -806,9 +819,24 @@ def encode_pretraining( return ret -def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): - encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = load_dataset(path, streaming=True, split="train") +def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): + if cfg.sample_packing: + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens + ) + encode = functools.partial( + encode_packed_pretraining, + tokenizer, + collate_fn, + max_seq_length=max_tokens, + batch_size=cfg.micro_batch_size, + ) + # set this to 1 so downstream data_loader doesn't try to increase the batch again + cfg.micro_batch_size = 1 + else: + encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + + dataset = load_dataset(path, streaming=True, split="train", name=name) dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.map( encode, @@ -819,3 +847,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): remove_columns=dataset.features.keys(), ) return dataset + + +def encode_packed_pretraining( + tokenizer: PreTrainedTokenizerBase, + collate_fn, + examples: List[str], + max_seq_length: int = 2048, + batch_size: int = 4, +) -> Dict[str, List]: + # pylint: disable=duplicate-code + # tokenize all the examples + # rows get split with stride (overlap) + res = tokenizer( + examples, + truncation=True, + max_length=max_seq_length - 1, + add_special_tokens=True, + return_overflowing_tokens=True, + stride=256, + ) + + input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]] + attention_mask = [seq + [1] for seq in res["attention_mask"]] + + tokenized_examples = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + train_dataset = Dataset.from_dict(tokenized_examples) + train_dataset = process_pretraining_datasets_for_packing( + train_dataset, max_seq_length + ) + + sampler = MultipackBatchSampler( + RandomSampler(train_dataset), + batch_size=batch_size, + drop_last=True, + batch_max_len=batch_size * max_seq_length, + lengths=( + train_dataset.data.column("position_ids") + .to_pandas() + .apply(lambda x: x[-1] + 1) + .values + ), + ) + + chunked_data = defaultdict(list) + + for data in sampler: + features = train_dataset[data] + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) + + for feature in features.keys(): + if feature == "length": + continue + chunked_data[feature].append(collated_features[feature].squeeze(0)) + + return chunked_data diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 022229af85..b30ffcad8c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -26,6 +26,7 @@ from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.chat_templates import chat_templates from axolotl.utils.dict import DictDefault LOG = logging.getLogger("axolotl") @@ -136,6 +137,23 @@ def load_tokenizer(cfg): if cfg.special_tokens: for k, val in cfg.special_tokens.items(): + # check if new special token is not already in tokenizer and + # is adapter training to make sure lora_modules_to_save is set + if ( + (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) + and cfg.adapter + and ( + not cfg.lora_modules_to_save + or not all( + x in cfg.lora_modules_to_save + for x in ["embed_tokens", "lm_head"] + ) + ) + ): + raise ValueError( + "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens." + ) + tokenizer.add_special_tokens( {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} ) @@ -169,6 +187,12 @@ def load_tokenizer(cfg): LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + if cfg.chat_template: + tokenizer.chat_template = chat_templates(cfg.chat_template) + else: + LOG.info( + "No Chat template selected. Consider adding a chat template for easier inference." + ) return tokenizer @@ -176,6 +200,7 @@ def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, inference: bool = False, + reference_model: bool = False, ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -230,17 +255,6 @@ def load_model( LOG.info("patching with sdp attention") hijack_llama_sdp_attention() - elif cfg.is_llama_derived_model and cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( - MEM_TOKEN, - patch_llama_with_landmark_attn, - ) - - LOG.info("patching with landmark attention") - patch_llama_with_landmark_attn() - - # Note: This might overwrite previous additional_special_tokens - tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( @@ -262,14 +276,6 @@ def load_model( LOG.info("patching with flash attention") replace_mixtral_attn_with_multipack_flash_attn() - if cfg.is_llama_derived_model and cfg.xpos_rope: - from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( - replace_llama_rope_with_xpos_rope, - ) - - LOG.info("patching with xpos rope") - replace_llama_rope_with_xpos_rope() - if ( cfg.is_llama_derived_model and (cfg.max_packed_sequence_len or cfg.sample_packing) @@ -285,6 +291,15 @@ def load_model( model_kwargs["device_map"] = cfg.device_map model_kwargs["max_memory"] = cfg.max_memory model_kwargs["torch_dtype"] = cfg.torch_dtype + # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) if is_deepspeed_zero3_enabled(): del model_kwargs["device_map"] @@ -303,13 +318,20 @@ def load_model( **model_config.quantization_config ) if cfg.adapter == "qlora" and cfg.load_in_4bit: + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + "bnb_4bit_compute_dtype": cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + "bnb_4bit_quant_type": "nf4", + } + + if cfg.bnb_config_kwargs: + bnb_config.update(cfg.bnb_config_kwargs) + model_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=cfg.torch_dtype, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", + **bnb_config, ) # sample packing uses custom FA2 patch if cfg.flash_attention: @@ -320,15 +342,18 @@ def load_model( or cfg.is_mistral_derived_model or model_config.model_type == "mixtral" ): + model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: if model_config.model_type == "mixtral": + model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: + model_kwargs["attn_implementation"] = "eager" model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) @@ -545,9 +570,11 @@ def load_model( if hasattr(module, "weight"): module.to(cfg.torch_dtype) - model, lora_config = load_adapter(model, cfg, cfg.adapter) + lora_config = None + if not reference_model or cfg.lora_model_dir: + model, lora_config = load_adapter(model, cfg, cfg.adapter) - if cfg.ddp and not load_in_8bit: + if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): model.to(f"cuda:{cfg.local_rank}") if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f046dd7be8..3139f56004 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -12,7 +12,7 @@ from datasets import set_caching_enabled from torch.utils.data import DataLoader, RandomSampler -from axolotl.core.trainer_builder import HFCausalTrainerBuilder +from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.samplers import MultipackBatchSampler @@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): return train_dataset, eval_dataset +def process_pretraining_datasets_for_packing(train_dataset, sequence_len): + drop_long = partial(drop_long_seq, sequence_len=sequence_len) + + train_dataset = train_dataset.filter(drop_long) + train_dataset = train_dataset.map( + add_position_ids, + ) + return train_dataset + + def calculate_total_num_steps(cfg, train_dataset, update=True): if not cfg.total_num_tokens: total_num_tokens = np.sum( @@ -280,7 +290,12 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) + if cfg.rl: + trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.model_ref = model[1] + else: + trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py new file mode 100644 index 0000000000..e8987ef452 --- /dev/null +++ b/tests/core/test_trainer_builder.py @@ -0,0 +1,59 @@ +""" +unit tests for axolotl.core.trainer_builder +""" +import pytest + +from axolotl.core.trainer_builder import HFDPOTrainerBuilder +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + + +@pytest.fixture(name="cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "LlamaTokenizer", + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00005, + "save_steps": 100, + "output_dir": "./model-out", + "warmup_steps": 10, + "gradient_checkpointing": False, + "optimizer": "adamw_torch", + "sequence_len": 2048, + "rl": True, + "adam_beta1": 0.998, + "adam_beta2": 0.9, + "adam_epsilon": 0.00001, + "dataloader_num_workers": 1, + "dataloader_pin_memory": True, + } + ) + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(cfg): + return load_tokenizer(cfg) + + +@pytest.fixture(name="model") +def fixture_model(cfg, tokenizer): + return load_model(cfg, tokenizer) + + +class TestHFDPOTrainerBuilder: + """ + TestCase class for DPO trainer builder + """ + + def test_build_training_arguments(self, cfg, model, tokenizer): + builder = HFDPOTrainerBuilder(cfg, model, tokenizer) + training_arguments = builder.build_training_arguments(100) + assert training_arguments.adam_beta1 == 0.998 + assert training_arguments.adam_beta2 == 0.9 + assert training_arguments.adam_epsilon == 0.00001 + assert training_arguments.dataloader_num_workers == 1 + assert training_arguments.dataloader_pin_memory is True diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py new file mode 100644 index 0000000000..896cc74d0f --- /dev/null +++ b/tests/e2e/test_mixtral.py @@ -0,0 +1,109 @@ +""" +E2E tests for mixtral +""" + +import logging +import os +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMixtral(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/tests/e2e/test_mixtral_samplepack.py b/tests/e2e/test_mixtral_samplepack.py new file mode 100644 index 0000000000..b43702a512 --- /dev/null +++ b/tests/e2e/test_mixtral_samplepack.py @@ -0,0 +1,123 @@ +""" +E2E tests for mixtral +""" + +import logging +import os +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMixtral(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 2048, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "sample_packing": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "sample_packing": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert ( + "axolotl.monkeypatch.mixtral.modeling_mixtral" + in model.model.layers[0].self_attn.__class__.__module__ + ) + assert ( + "MixtralMultipackFlashAttention2" + in model.model.layers[0].self_attn.__class__.__name__ + ) + assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/tests/e2e/test_model_patches.py b/tests/e2e/test_model_patches.py new file mode 100644 index 0000000000..eb11244644 --- /dev/null +++ b/tests/e2e/test_model_patches.py @@ -0,0 +1,99 @@ +""" +E2E smoke tests to check that the monkeypatches are in place for certain configurations +""" + +import unittest + +from axolotl.common.cli import TrainerCliArgs +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + +from .utils import with_temp_dir + + +class TestModelPatches(unittest.TestCase): + """ + TestCases for the multipack monkey patches + """ + + @with_temp_dir + def test_mixtral_multipack(self, temp_dir): + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + tokenizer = load_tokenizer(cfg) + model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + + assert ( + "axolotl.monkeypatch.mixtral.modeling_mixtral" + in model.model.layers[0].self_attn.__class__.__module__ + ) + assert ( + "MixtralMultipackFlashAttention2" + in model.model.layers[0].self_attn.__class__.__name__ + ) + + @with_temp_dir + def test_mistral_multipack(self, temp_dir): + cfg = DictDefault( + { + "base_model": "openaccess-ai-collective/tiny-mistral", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + tokenizer = load_tokenizer(cfg) + model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + + assert ( + "axolotl.monkeypatch.mistral_attn_hijack_flash" + in model.model.layers[0].self_attn.forward.__module__ + ) diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py new file mode 100644 index 0000000000..9dee6ed7f0 --- /dev/null +++ b/tests/test_packed_pretraining.py @@ -0,0 +1,87 @@ +"""Module for testing streaming dataset sequence packing""" +import math +import unittest +from functools import partial + +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from axolotl.utils.collators import DataCollatorForSeq2Seq +from axolotl.utils.data import encode_packed_pretraining + + +class TestPacking(unittest.TestCase): + """ + Test class for packing streaming dataset sequences + """ + + def setUp(self) -> None: + # pylint: disable=duplicate-code + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "[PAD]", + } + ) + self.max_seq_length = 8192 + self.batch_size = 6 + self.sample_packing_efficiency = 1 + self.data_collator_kwargs = { + "padding": True, + "pad_to_multiple_of": 64 * math.ceil(self.max_seq_length / 64), + } + + def test_packing_stream_dataset(self): + # pylint: disable=duplicate-code + dataset = load_dataset( + "c4", + "en", + streaming=True, + )["train"] + + encode = partial( + encode_packed_pretraining, + self.tokenizer, + max_seq_length=self.max_seq_length, + sample_packing_efficiency=self.sample_packing_efficiency, + ) + + dataset = dataset.map( + encode, + batched=True, + input_columns="text", + remove_columns=dataset.features.keys(), + ) + + data_collator_fn = DataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + **self.data_collator_kwargs, + ) + + trainer_loader = DataLoader( + dataset, + batch_size=self.batch_size, + collate_fn=data_collator_fn, + drop_last=True, + ) + idx = 0 + for data in trainer_loader: + if idx > 10: + break + assert data["input_ids"].shape == (self.batch_size, self.max_seq_length) + assert data["position_ids"].shape == (self.batch_size, self.max_seq_length) + assert data["labels"].shape == (self.batch_size, self.max_seq_length) + assert data["attention_mask"].shape == ( + self.batch_size, + self.max_seq_length, + ) + idx += 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 0635bd718b..cea39d0adf 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -2,6 +2,7 @@ import json import logging import unittest +from copy import deepcopy from pathlib import Path from typing import Optional @@ -25,6 +26,50 @@ LOG = logging.getLogger("axolotl") +test_data = { + "multi_turn_sys": { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + }, + "single_turn_sys": { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + }, + "single_turn_no_sys": { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + }, + "multi_turn_no_sys": { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + }, +} + + +def prompt_strat(conversation, tokenizer): + "Helper function to create a prompt strategy for testing." + prompter = ShareGPTPrompterV2(conversation=conversation) + return ShareGPTPromptTokenizingStrategy( + prompter, + tokenizer, + False, + 2048, + ) + class TestPromptTokenizationStrategies(unittest.TestCase): """ @@ -114,6 +159,70 @@ def test_sharegpt_warnings_turns(self): in self._caplog.records[0].message ) + def test_sharegpt_llama(self): + "Make sure the sharegpt/llama is tokenized and formatted correctly." + strat = prompt_strat("llama-2", self.tokenizer) + + def tokenize(conv): + return strat.tokenize_prompt(deepcopy(conv))["input_ids"] + + def decode(ids): + return strat.tokenizer.decode(ids) + + # fmt: off + # System message, multi-turn conversations + mt_ids = tokenize(test_data['multi_turn_sys']) + assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + + # System message, single-turn conversations + st_ids = tokenize(test_data['single_turn_sys']) + assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, single-turn + ns_ids = tokenize(test_data['single_turn_no_sys']) + assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' + assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, multi-turn + ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) + assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' + assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + # fmt: on + + def test_sharegpt_mistral(self): + "Make sure the sharegpt/mistral is tokenized and formatted correctly." + strat = prompt_strat("mistral", self.tokenizer) + + def tokenize(conv): + return strat.tokenize_prompt(deepcopy(conv))["input_ids"] + + def decode(ids): + return strat.tokenizer.decode(ids) + + # fmt: off + # System message, multi-turn conversations + mt_ids = tokenize(test_data['multi_turn_sys']) + assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit
' + assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + + # System message, single-turn conversations + st_ids = tokenize(test_data['single_turn_sys']) + assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, single-turn + ns_ids = tokenize(test_data['single_turn_no_sys']) + assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' + assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, multi-turn + ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) + assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit
' + assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + # fmt: on + def test_sharegpt_changes_roles(self): conversation = { "roles": ["USER", "CHARACTER"], diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index 5c83391942..bfe4f06af9 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -3,6 +3,8 @@ """ import unittest +import pytest + from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_tokenizer @@ -31,6 +33,40 @@ def test_dont_use_fast(self): tokenizer = load_tokenizer(cfg) assert "Fast" not in tokenizer.__class__.__name__ + def test_special_tokens_modules_to_save(self): + # setting special_tokens to new token + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "adapter": "lora", + "special_tokens": {"bos_token": "[INST]"}, + } + ) + with pytest.raises( + ValueError, + match=r".*Please set lora_modules_to_save*", + ): + load_tokenizer(cfg) + + # setting special_tokens but not changing from default + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "adapter": "lora", + "special_tokens": {"bos_token": ""}, + } + ) + load_tokenizer(cfg) + + # non-adapter setting special_tokens + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "special_tokens": {"bos_token": "[INST]"}, + } + ) + load_tokenizer(cfg) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_validation.py b/tests/test_validation.py index fabc23da33..12997b023b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -682,6 +682,43 @@ def test_warmup_step_no_conflict(self): validate_config(cfg) + def test_add_tokens_adapter(self): + cfg = DictDefault( + {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]} + ) + + with pytest.raises( + ValueError, + match=r".*lora_modules_to_save not properly set yet adding new tokens*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_4bit": True, + "tokens": ["<|imstart|>"], + "lora_modules_to_save": ["embed_tokens"], + } + ) + + with pytest.raises( + ValueError, + match=r".*lora_modules_to_save not properly set yet adding new tokens*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_4bit": True, + "tokens": ["<|imstart|>"], + "lora_modules_to_save": ["embed_tokens", "lm_head"], + } + ) + + validate_config(cfg) + class ValidationWandbTest(ValidationTest): """