From 5ada140ff02edb18dfbf9b2cdc08c13203cb0e7d Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Thu, 14 Dec 2023 10:03:59 -0800 Subject: [PATCH 01/41] Fix prompt assembly for llama (#952) * start at index 0 * add test to check for missing turns * apply black * Update test_prompt_tokenizers.py * Update src/axolotl/monkeypatch/fastchat_conversation_turns.py Co-authored-by: Motoki Wu * fix linting * apply black * add more tests for llama/sharegpt * make logic clearer --------- Co-authored-by: Motoki Wu --- .../fastchat_conversation_turns.py | 17 +++-- tests/test_prompt_tokenizers.py | 70 +++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index 19313fb7e2..e1065a950f 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements yield role + ":", "" return if self.sep_style == SeparatorStyle.LLAMA2: - seps = [self.sep, self.sep2] if self.system_message: + if self.messages: + # For llama, the system message is incorporated into the first human instruction + first_role, first_msg = self.messages[0] + if first_role == self.roles[0]: + system_prompt += first_msg + self.messages.pop(0) yield "", system_prompt - else: - yield "", "[INST] " - for i, (role, message) in enumerate(self.messages[1:]): + for i, (role, message) in enumerate(self.messages): if message: - yield role + " ", message + seps[i % 2] + if (i % 2 == 0 and not self.system_message) or ( + i % 2 != 0 and self.system_message + ): + role = " " + role + yield role + " ", message else: yield role, "" return diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 0635bd718b..6e57ffb370 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -114,6 +114,76 @@ def test_sharegpt_warnings_turns(self): in self._caplog.records[0].message ) + def test_sharegpt_llama(self): + "Make sure the sharegpt/llama is tokenized and formatted correctly." + prompter = ShareGPTPrompterV2(conversation="llama-2") + strat = ShareGPTPromptTokenizingStrategy( + prompter, + self.tokenizer, + False, + 2048, + ) + + def tokenize(conv): + return strat.tokenize_prompt(conv)["input_ids"] + + def decode(ids): + return strat.tokenizer.decode(ids) + + # Multi-turn conversations + multi_turn_conv = { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + } + # fmt: off + mt_ids = tokenize(multi_turn_conv) + assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + + # Single-turn conversations + single_turn_conv = { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + } + + st_ids = tokenize(single_turn_conv) + assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, single-turn + no_sys_conv = { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + } + + ns_ids = tokenize(no_sys_conv) + assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' + assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, multi-turn + no_sys_mt_conv = { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + } + ns_mt_ids = tokenize(no_sys_mt_conv) + assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' + assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + # fmt: on + def test_sharegpt_changes_roles(self): conversation = { "roles": ["USER", "CHARACTER"], From f28e75513bff30c7a622fde5c26530cf0332ff71 Mon Sep 17 00:00:00 2001 From: dumpmemory <64742282+dumpmemory@users.noreply.github.com> Date: Sat, 16 Dec 2023 10:03:17 +0800 Subject: [PATCH 02/41] update transformers to fix checkpoint saving (#963) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bbee7cf45b..a2899fa6fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ auto-gptq==0.5.1 packaging peft==0.6.0 -transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d +transformers @ git+https://github.com/huggingface/transformers.git@ebfdb9ca62205279d5019ef1403877461b3b2da4 tokenizers==0.15.0 bitsandbytes>=0.41.1 accelerate==0.24.1 From 80ec7af358d363f2f4cd893312e1ccaf4a9d60ba Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 16 Dec 2023 18:31:25 -0500 Subject: [PATCH 03/41] update to latest nccl in docker image (#965) --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 6eea7322ce..790076687e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1" ENV PYTORCH_VERSION=$PYTORCH_VERSION RUN apt-get update && \ - apt-get install -y vim curl + apt-get install -y vim curl nano libnccl2 libnccl-dev WORKDIR /workspace From 85de004dd47ea5e4e35f8b8a988000a7bda868c8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 16 Dec 2023 19:12:01 -0500 Subject: [PATCH 04/41] fix for build for nccl in dockerfile (#970) --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 790076687e..41915de83d 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1" ENV PYTORCH_VERSION=$PYTORCH_VERSION RUN apt-get update && \ - apt-get install -y vim curl nano libnccl2 libnccl-dev + apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev WORKDIR /workspace From 13e938149dd6a1eec317c4cdde74d3ce991a8b86 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 17 Dec 2023 18:48:28 +0900 Subject: [PATCH 05/41] fix: add lr scheduler kwargs to Trainer (#972) --- src/axolotl/core/trainer_builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ccd9d37c0d..cc162d210a 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -692,6 +692,9 @@ def build(self, total_num_steps): and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine" ) + training_arguments_kwargs["lr_scheduler_kwargs"] = ( + self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} + ) training_arguments_kwargs["weight_decay"] = ( self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 ) From d25c34caa6f324cdd7521138f4515e6ed7dc258d Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sun, 17 Dec 2023 23:51:25 +0900 Subject: [PATCH 06/41] Update README.md (#966) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c03eec54b4..f0e399ace8 100644 --- a/README.md +++ b/README.md @@ -981,7 +981,7 @@ wandb_log_model: ##### Special Tokens -It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocubulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: +It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: ```yml special_tokens: From 161bcb6517ef38fee60ce5b243b6ac68d8bbdef7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 21 Dec 2023 09:38:20 -0500 Subject: [PATCH 07/41] Dockerfile torch fix (#987) * add torch to requirements.txt at build time to force version to stick * fix xformers check * better handling of xformers based on installed torch version * fix for ci w/o torch --- .github/workflows/base.yml | 2 +- .github/workflows/main.yml | 4 ++-- docker/Dockerfile | 1 - setup.py | 15 +++++++++------ 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5f08854842..1dbff114ed 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,7 @@ jobs: - cuda: "118" cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9514208b1c..87b308362b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: @@ -80,7 +80,7 @@ jobs: - cuda: 118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.1.0 + pytorch: 2.1.1 axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: diff --git a/docker/Dockerfile b/docker/Dockerfile index 41915de83d..81a08bc8b7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets -RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ else \ diff --git a/setup.py b/setup.py index 42fd22df11..fe4d2cfad8 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ """setup.py for axolotl""" +from importlib.metadata import PackageNotFoundError, version + from setuptools import find_packages, setup @@ -22,12 +24,13 @@ def parse_requirements(): # Handle standard packages _install_requires.append(line) - # TODO(wing) remove once xformers release supports torch 2.1.0 - if "torch==2.1.0" in _install_requires: - _install_requires.pop(_install_requires.index("xformers>=0.0.22")) - _install_requires.append( - "xformers @ git+https://github.com/facebookresearch/xformers.git@main" - ) + try: + torch_version = version("torch") + if torch_version.startswith("2.1.1"): + _install_requires.pop(_install_requires.index("xformers==0.0.22")) + _install_requires.append("xformers==0.0.23") + except PackageNotFoundError: + pass return _install_requires, _dependency_links From 7bbaac98f75e808b7903954a8ff32ae11c35feea Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Thu, 21 Dec 2023 08:00:55 -0800 Subject: [PATCH 08/41] fix mistral prompt assembly (#982) * fix mistral prompts * fix spacing * remove elif --- .../fastchat_conversation_turns.py | 24 +++- tests/test_prompt_tokenizers.py | 131 ++++++++++++------ 2 files changed, 108 insertions(+), 47 deletions(-) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index e1065a950f..068261da36 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + ":", "" return - if self.sep_style == SeparatorStyle.LLAMA2: + if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": if self.system_message: if self.messages: # For llama, the system message is incorporated into the first human instruction @@ -101,6 +101,28 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return + if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": + contains_sys_msg = False + if self.system_message: + contains_sys_msg = True + if self.messages: + # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline + first_role, first_msg = self.messages[0] + if first_role == self.roles[0]: + system_prompt = self.system_template.format( + system_message=" " + self.system_message + ) + system_prompt += first_msg + self.messages.pop(0) + yield "", system_prompt + for i, (role, message) in enumerate(self.messages): + if message and i == 0 and not contains_sys_msg: + yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a ` [INST]` at the beginning of the first instruction. + elif message: + yield role + " ", message + else: + yield role, "" + return if self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 6e57ffb370..cea39d0adf 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -2,6 +2,7 @@ import json import logging import unittest +from copy import deepcopy from pathlib import Path from typing import Optional @@ -25,6 +26,50 @@ LOG = logging.getLogger("axolotl") +test_data = { + "multi_turn_sys": { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + }, + "single_turn_sys": { + "conversations": [ + {"from": "system", "value": "lorem"}, + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + }, + "single_turn_no_sys": { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + ] + }, + "multi_turn_no_sys": { + "conversations": [ + {"from": "human", "value": "abc"}, + {"from": "gpt", "value": "ipsum"}, + {"from": "human", "value": "123"}, + {"from": "gpt", "value": "sit"}, + ] + }, +} + + +def prompt_strat(conversation, tokenizer): + "Helper function to create a prompt strategy for testing." + prompter = ShareGPTPrompterV2(conversation=conversation) + return ShareGPTPromptTokenizingStrategy( + prompter, + tokenizer, + False, + 2048, + ) + class TestPromptTokenizationStrategies(unittest.TestCase): """ @@ -116,74 +161,68 @@ def test_sharegpt_warnings_turns(self): def test_sharegpt_llama(self): "Make sure the sharegpt/llama is tokenized and formatted correctly." - prompter = ShareGPTPrompterV2(conversation="llama-2") - strat = ShareGPTPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) + strat = prompt_strat("llama-2", self.tokenizer) def tokenize(conv): - return strat.tokenize_prompt(conv)["input_ids"] + return strat.tokenize_prompt(deepcopy(conv))["input_ids"] def decode(ids): return strat.tokenizer.decode(ids) - # Multi-turn conversations - multi_turn_conv = { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "123"}, - {"from": "gpt", "value": "sit"}, - ] - } # fmt: off - mt_ids = tokenize(multi_turn_conv) + # System message, multi-turn conversations + mt_ids = tokenize(test_data['multi_turn_sys']) assert decode(mt_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum [INST] 123 [/INST] sit' assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] - # Single-turn conversations - single_turn_conv = { - "conversations": [ - {"from": "system", "value": "lorem"}, - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - ] - } - - st_ids = tokenize(single_turn_conv) + # System message, single-turn conversations + st_ids = tokenize(test_data['single_turn_sys']) assert decode(st_ids) == ' [INST] <>\nlorem\n<>\n\nabc [/INST] ipsum' assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] # No system message, single-turn - no_sys_conv = { - "conversations": [ - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - ] - } - - ns_ids = tokenize(no_sys_conv) + ns_ids = tokenize(test_data['single_turn_no_sys']) assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] # No system message, multi-turn - no_sys_mt_conv = { - "conversations": [ - {"from": "human", "value": "abc"}, - {"from": "gpt", "value": "ipsum"}, - {"from": "human", "value": "123"}, - {"from": "gpt", "value": "sit"}, - ] - } - ns_mt_ids = tokenize(no_sys_mt_conv) + ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] # fmt: on + def test_sharegpt_mistral(self): + "Make sure the sharegpt/mistral is tokenized and formatted correctly." + strat = prompt_strat("mistral", self.tokenizer) + + def tokenize(conv): + return strat.tokenize_prompt(deepcopy(conv))["input_ids"] + + def decode(ids): + return strat.tokenizer.decode(ids) + + # fmt: off + # System message, multi-turn conversations + mt_ids = tokenize(test_data['multi_turn_sys']) + assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit' + assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + + # System message, single-turn conversations + st_ids = tokenize(test_data['single_turn_sys']) + assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum' + assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, single-turn + ns_ids = tokenize(test_data['single_turn_no_sys']) + assert decode(ns_ids) == ' [INST] abc [/INST] ipsum' + assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2] + + # No system message, multi-turn + ns_mt_ids = tokenize(test_data['multi_turn_no_sys']) + assert decode(ns_mt_ids) == ' [INST] abc [/INST] ipsum [INST] 123 [/INST] sit' + assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2] + # fmt: on + def test_sharegpt_changes_roles(self): conversation = { "roles": ["USER", "CHARACTER"], From 62ba1609b6ca37ca5f344894ced0d45bcfb2a9d4 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Thu, 21 Dec 2023 08:54:08 -0800 Subject: [PATCH 09/41] bump actions versions --- .github/workflows/main.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 87b308362b..bdab96ac21 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,21 +32,21 @@ jobs: runs-on: [self-hosted, gpu, docker] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Docker metadata id: metadata - uses: docker/metadata-action@v3 + uses: docker/metadata-action@v5 with: images: winglian/axolotl - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 + uses: docker/setup-buildx-action@v3 - name: Build - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v5 with: context: . build-args: | @@ -85,21 +85,21 @@ jobs: runs-on: [self-hosted, gpu, docker] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Docker metadata id: metadata - uses: docker/metadata-action@v3 + uses: docker/metadata-action@v5 with: images: winglian/axolotl-runpod - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - name: Build - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v5 with: context: . build-args: | From 1ffa3866f2500fce827bc60f3907a2103ba3ac54 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 22 Dec 2023 21:49:07 +0900 Subject: [PATCH 10/41] Feat: Warns to add to modules_to_save when adding tokens or switching special_tokens (#787) * Feat: Auto add to modules_to_save when adding tokens * fix: swap to error instead of warning * feat: add check when special_tokens differ and add test --- src/axolotl/utils/config.py | 14 ++++++++++++++ src/axolotl/utils/models.py | 17 +++++++++++++++++ tests/test_tokenizers.py | 36 ++++++++++++++++++++++++++++++++++++ tests/test_validation.py | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 1b4ce92465..d9e56b95a6 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -448,6 +448,20 @@ def validate_config(cfg): if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0: raise ValueError("neftune_noise_alpha must be > 0.0") + if ( + cfg.adapter + and cfg.tokens + and ( + not cfg.lora_modules_to_save + or not all( + x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"] + ) + ) + ): + raise ValueError( + "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`." + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 022229af85..8cb9e8426a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -136,6 +136,23 @@ def load_tokenizer(cfg): if cfg.special_tokens: for k, val in cfg.special_tokens.items(): + # check if new special token is not already in tokenizer and + # is adapter training to make sure lora_modules_to_save is set + if ( + (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) + and cfg.adapter + and ( + not cfg.lora_modules_to_save + or not all( + x in cfg.lora_modules_to_save + for x in ["embed_tokens", "lm_head"] + ) + ) + ): + raise ValueError( + "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens." + ) + tokenizer.add_special_tokens( {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} ) diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index 5c83391942..bfe4f06af9 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -3,6 +3,8 @@ """ import unittest +import pytest + from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_tokenizer @@ -31,6 +33,40 @@ def test_dont_use_fast(self): tokenizer = load_tokenizer(cfg) assert "Fast" not in tokenizer.__class__.__name__ + def test_special_tokens_modules_to_save(self): + # setting special_tokens to new token + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "adapter": "lora", + "special_tokens": {"bos_token": "[INST]"}, + } + ) + with pytest.raises( + ValueError, + match=r".*Please set lora_modules_to_save*", + ): + load_tokenizer(cfg) + + # setting special_tokens but not changing from default + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "adapter": "lora", + "special_tokens": {"bos_token": ""}, + } + ) + load_tokenizer(cfg) + + # non-adapter setting special_tokens + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "special_tokens": {"bos_token": "[INST]"}, + } + ) + load_tokenizer(cfg) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_validation.py b/tests/test_validation.py index fabc23da33..12997b023b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -682,6 +682,43 @@ def test_warmup_step_no_conflict(self): validate_config(cfg) + def test_add_tokens_adapter(self): + cfg = DictDefault( + {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]} + ) + + with pytest.raises( + ValueError, + match=r".*lora_modules_to_save not properly set yet adding new tokens*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_4bit": True, + "tokens": ["<|imstart|>"], + "lora_modules_to_save": ["embed_tokens"], + } + ) + + with pytest.raises( + ValueError, + match=r".*lora_modules_to_save not properly set yet adding new tokens*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_4bit": True, + "tokens": ["<|imstart|>"], + "lora_modules_to_save": ["embed_tokens", "lm_head"], + } + ) + + validate_config(cfg) + class ValidationWandbTest(ValidationTest): """ From 2e61dc31802f4762622bb68849d8b8c6d87dfebe Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Fri, 22 Dec 2023 06:37:20 -0800 Subject: [PATCH 11/41] Add tests to Docker (#993) --- .github/workflows/main.yml | 21 +++++++++++++++++---- docker/Dockerfile | 3 +++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bdab96ac21..3eb97e55d4 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,27 +38,40 @@ jobs: uses: docker/metadata-action@v5 with: images: winglian/axolotl + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Build + # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/ + - name: Build and export to Docker uses: docker/build-push-action@v5 with: context: . + load: true build-args: | BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} CUDA=${{ matrix.cuda }} PYTORCH_VERSION=${{ matrix.pytorch }} file: ./docker/Dockerfile - push: ${{ github.event_name != 'pull_request' }} tags: | ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} labels: ${{ steps.metadata.outputs.labels }} + - name: Unit Tests + run: | + docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ + - name: Push to Docker Hub + if: github.event_name != 'pull_request' + run: | + docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} + if [ -n "$latest_tag" ]; then + docker push "$latest_tag" + fi + build-axolotl-runpod: needs: build-axolotl if: github.repository_owner == 'OpenAccess-AI-Collective' diff --git a/docker/Dockerfile b/docker/Dockerfile index 81a08bc8b7..f8e0528562 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -25,6 +25,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn]; \ fi +# So we can test the Docker image +RUN pip install pytest + # fix so that git fetch/pull from remote works RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ git config --get remote.origin.fetch From 93ebec1ac51a1c485b9dfd57a6b5e4f6c870ad49 Mon Sep 17 00:00:00 2001 From: mhenrichsen Date: Fri, 22 Dec 2023 16:18:16 +0100 Subject: [PATCH 12/41] change val size (#992) --- examples/mistral/qlora.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 64b26f4fa3..35c79ebf4e 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -11,7 +11,7 @@ datasets: - path: mhenrichsen/alpaca_2k_test type: alpaca dataset_prepared_path: last_run_prepared -val_set_size: 0.05 +val_set_size: 0.1 output_dir: ./qlora-out adapter: qlora From 7d4185ffcbb424ce39d2fa5430753bb56c371f88 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 23 Dec 2023 00:29:36 +0900 Subject: [PATCH 13/41] chore: Update transformers to latest (#986) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a2899fa6fc..c1c1cbc132 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ auto-gptq==0.5.1 packaging peft==0.6.0 -transformers @ git+https://github.com/huggingface/transformers.git@ebfdb9ca62205279d5019ef1403877461b3b2da4 +transformers==4.36.2 tokenizers==0.15.0 bitsandbytes>=0.41.1 accelerate==0.24.1 From 37820f65403b5212bb52f46eed8a455363f21544 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 22 Dec 2023 11:08:22 -0500 Subject: [PATCH 14/41] support for cuda 12.1 (#989) --- .github/workflows/base.yml | 5 +++++ .github/workflows/main.yml | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 1dbff114ed..6b90d1b501 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -30,6 +30,11 @@ jobs: python_version: "3.10" pytorch: 2.1.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" + - cuda: "121" + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3eb97e55d4..2f0b074501 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,6 +29,11 @@ jobs: python_version: "3.10" pytorch: 2.1.1 axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 + axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: - name: Checkout @@ -95,6 +100,11 @@ jobs: python_version: "3.10" pytorch: 2.1.1 axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 + axolotl_extras: runs-on: [self-hosted, gpu, docker] steps: - name: Checkout From 628b754824008f2d7c1aad079925a1d8e8cf9f48 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 22 Dec 2023 12:57:02 -0500 Subject: [PATCH 15/41] set output_router_logits for mixtral config: (#995) --- examples/mistral/mixtral.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 15df88f967..11c842d4ee 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -23,6 +23,9 @@ unfrozen_parameters: # - model.layers.3[0-9]+.block_sparse_moe.gate.* # - model.layers.3[0-9]+.block_sparse_moe.experts.* +model_config: + output_router_logits: true + adapter: qlora lora_model_dir: From 6ef46f8dcac84825c39fcea57917260abd33e9ac Mon Sep 17 00:00:00 2001 From: Evan Griffiths <56087052+evangriffiths@users.noreply.github.com> Date: Mon, 25 Dec 2023 18:29:55 +0000 Subject: [PATCH 16/41] Add an example config for finetuning a 34B model on a 24GB GPU (#1000) * Add an example config for finetuning a 34B model on a 24GB GPU * Remore wandb project --- examples/yi-34B-chat/README.md | 5 +++ examples/yi-34B-chat/qlora.yml | 76 ++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 examples/yi-34B-chat/README.md create mode 100644 examples/yi-34B-chat/qlora.yml diff --git a/examples/yi-34B-chat/README.md b/examples/yi-34B-chat/README.md new file mode 100644 index 0000000000..07078850fb --- /dev/null +++ b/examples/yi-34B-chat/README.md @@ -0,0 +1,5 @@ +# Overview + +This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM. + +Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory. diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml new file mode 100644 index 0000000000..0c1a4b7889 --- /dev/null +++ b/examples/yi-34B-chat/qlora.yml @@ -0,0 +1,76 @@ +base_model: 01-ai/Yi-34B-Chat +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_mistral_derived_model: false +is_llama_derived_model: true +load_in_8bit: false +load_in_4bit: true +strict: false +sequence_len: 1024 +bf16: true +fp16: false +tf32: false +flash_attention: true +special_tokens: + bos_token: "<|startoftext|>" + eos_token: "<|endoftext|>" + unk_token: "" + +# Data +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +warmup_steps: 10 + +# Iterations +num_epochs: 1 + +# Evaluation +val_set_size: 0.1 +evals_per_epoch: 5 +eval_table_size: +eval_table_max_new_tokens: 128 +eval_sample_packing: false +eval_batch_size: 1 + +# LoRA +output_dir: ./qlora-out +adapter: qlora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_target_modules: + +# Sampling +sample_packing: false +pad_to_sequence_len: false + +# Batching +gradient_accumulation_steps: 4 +micro_batch_size: 1 +gradient_checkpointing: true + +# wandb +wandb_project: + +# Optimizer +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +# Misc +train_on_inputs: false +group_by_length: false +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +debug: +deepspeed: +weight_decay: 0 +fsdp: +fsdp_config: From db9094df0f271f857f73e9d64efe02a67013ee04 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 27 Dec 2023 23:25:20 +0100 Subject: [PATCH 17/41] FEAT: add tagging support to axolotl (#1004) * add tagging support to axolotl * chore: lint * fix method w self --------- Co-authored-by: Wing Lian --- src/axolotl/core/trainer_builder.py | 36 ++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index cc162d210a..c74114a176 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -9,7 +9,7 @@ import sys from abc import abstractmethod from dataclasses import dataclass, field -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Optional @@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer): """ args = None # type: AxolotlTrainingArguments + tag_names = ["axolotl"] def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs): self.num_epochs = num_epochs @@ -290,12 +291,41 @@ def compute_loss(self, model, inputs, return_outputs=False): # return (loss, outputs) if return_outputs else loss return super().compute_loss(model, inputs, return_outputs=return_outputs) + def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = self._sanitize_kwargs_for_tagging( + tag_names=self.tag_names, kwargs=kwargs + ) + + return super().push_to_hub(*args, **kwargs) + class AxolotlMambaTrainer(AxolotlTrainer): """ Mamba specific trainer to handle loss calculation """ + tag_names = ["axolotl", "mamba"] + def compute_loss( self, model, @@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer): Trainer subclass that uses the OneCycleLR scheduler """ + tag_names = ["axolotl", "onecycle"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lr_scheduler = None @@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer): Trainer subclass that uses the OneCycleLR scheduler """ + tag_names = ["axolotl", "relora"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lr_scheduler = None From 384b817dc046db148ff7d07be53bdfd03950a5f4 Mon Sep 17 00:00:00 2001 From: Kevin Sydney <139094933+kmsydney@users.noreply.github.com> Date: Wed, 27 Dec 2023 16:11:55 -0800 Subject: [PATCH 18/41] Set eval_sample_packing to false in mistral config.yaml (#1003) Without eval_sampling_packing set to false, ValueError occurs with eval dataset split is too small for sample_packing. --- examples/mistral/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index 1c37b05c13..ea62e9ebfe 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -17,6 +17,7 @@ output_dir: ./out sequence_len: 8192 sample_packing: true pad_to_sequence_len: true +eval_sample_packing: false wandb_project: wandb_entity: From 85dd4d525b9e65740ccb48d3c3897d35c9ae5265 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Wed, 27 Dec 2023 19:25:33 -0800 Subject: [PATCH 19/41] add config to model card (#1005) * add config to model card * rm space * apply black formatting * apply black formatting * fix formatting * check for cfg attribute * add version * add version * put the config in a collapsible element * put the config in a collapsible element --- src/axolotl/train.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 169fc51272..4e5241e4c8 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ from accelerate.logging import get_logger from datasets import Dataset from optimum.bettertransformer import BetterTransformer +from pkg_resources import get_distribution # type: ignore from transformers.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs @@ -115,6 +116,12 @@ def terminate_handler(_, __, model): badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + if getattr(cfg, "axolotl_config_path"): + raw_axolotl_cfg = Path(cfg.axolotl_config_path) + version = get_distribution("axolotl").version + if raw_axolotl_cfg.is_file(): + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" + LOG.info("Starting trainer...") if cfg.group_by_length: LOG.info("hang tight... sorting dataset for group_by_length") From 70b46ca4f45b9cec5ded9563a68bb08607b56a3c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 27 Dec 2023 23:07:27 -0600 Subject: [PATCH 20/41] remove landmark attn and xpos rope implementations (#1010) --- README.md | 5 - src/axolotl/cli/__init__.py | 16 - src/axolotl/core/trainer_builder.py | 22 +- .../monkeypatch/llama_landmark_attn.py | 1249 ----------------- .../xpos_rope_llama_monkey_patch.py | 94 -- src/axolotl/utils/models.py | 19 - 6 files changed, 1 insertion(+), 1404 deletions(-) delete mode 100644 src/axolotl/monkeypatch/llama_landmark_attn.py delete mode 100644 src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py diff --git a/README.md b/README.md index f0e399ace8..a4122772d1 100644 --- a/README.md +++ b/README.md @@ -798,11 +798,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: -# Landmark attention (only llama) -landmark_attention: -# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py -# LLaMA only -xpos_rope: # Resume from a specific checkpoint dir resume_from_checkpoint: diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 8ca4f7fe55..e6537ad052 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -103,14 +103,6 @@ def do_inference( importlib.import_module("axolotl.prompters"), prompter ) - if cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id - - set_model_mem_id(model, tokenizer) - model.set_mem_cache_args( - max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None - ) - model = model.to(cfg.device) while True: @@ -176,14 +168,6 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) - if cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id - - set_model_mem_id(model, tokenizer) - model.set_mem_cache_args( - max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None - ) - model = model.to(cfg.device) def generate(instruction): diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index c74114a176..fed26de464 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -9,7 +9,7 @@ import sys from abc import abstractmethod from dataclasses import dataclass, field -from functools import partial, wraps +from functools import wraps from pathlib import Path from typing import Optional @@ -780,26 +780,6 @@ def build(self, total_num_steps): # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html data_collator_kwargs["pad_to_multiple_of"] = 64 - if self.cfg.is_llama_derived_model and self.cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( - add_mem_tokens, - get_mem_id, - set_model_mem_id, - ) - - set_model_mem_id(self.model, self.tokenizer) - - LOG.info("Adding landmark attention tokens to dataset") - - for dataset in [self.train_dataset, self.eval_dataset]: - dataset = dataset.map( - partial( - add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer) - ), - batched=False, - num_proc=32, - ) - trainer_cls = self._get_trainer_cls() trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py deleted file mode 100644 index 24a98305f3..0000000000 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ /dev/null @@ -1,1249 +0,0 @@ -# pylint: skip-file -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -PyTorch LLaMA model. -Taken from https://github.com/epfml/landmark-attention/blob/main/llama/llama_mem.py and modified. -""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers import LlamaTokenizer -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import ( - LLAMA_INPUTS_DOCSTRING, - LLAMA_START_DOCSTRING, - LlamaMLP, - LlamaPreTrainedModel, - LlamaRMSNorm, - LlamaRotaryEmbedding, - _expand_mask, - _make_causal_mask, - rotate_half, -) -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) - -LOG = logging.getLogger("axolotl") - -_CONFIG_FOR_DOC = "LlamaConfig" - -MEM_TOKEN = "" # nosec - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - if q is None: - q_embed = None - else: - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LandmarkGroupedSoftmaxFunction(torch.autograd.Function): - """ - Landmark grouped softmax function. - """ - - # Note that forward, setup_context, and backward are @staticmethods - @staticmethod - def forward(ctx, x, dim, mem_cnt, resp_mem_idx): - new_shape = list(x.shape) - new_shape[dim] = mem_cnt # max_mem_cnt.item() - max_by_group = x.new_zeros((*new_shape,)) - max_by_group.scatter_reduce_( - src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False - ) - - maxes = torch.gather(max_by_group, dim, resp_mem_idx) - # x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes)) - x_exp = torch.exp((x - maxes).to(torch.float32)) - - cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype) - - cumsum_by_group.scatter_add_( - dim, - resp_mem_idx, - x_exp, - ) - denom = torch.gather(cumsum_by_group, dim, resp_mem_idx) - - # probs = torch.where(denom < 0.5, 0, x_exp / denom) - probs = x_exp / denom - - ctx.mem_cnt = mem_cnt - ctx.dim = dim - ctx.save_for_backward(resp_mem_idx, probs) - - return probs - - @staticmethod - def backward(ctx, grad_probs): - mem_cnt = ctx.mem_cnt - dim = ctx.dim - resp_mem_idx, probs = ctx.saved_tensors - grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[4]: - grad_pair = grad_probs * probs - - new_shape = list(probs.shape) - new_shape[dim] = mem_cnt # max_mem_cnt.item() - cumsum_by_group = grad_pair.new_zeros((*new_shape,)) - cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair) - - if ctx.needs_input_grad[0]: - grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx) - grad_x = grad_pair - probs * grad_sum - assert not ctx.needs_input_grad[1] - assert not ctx.needs_input_grad[2] - assert not ctx.needs_input_grad[3] - - return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx - - -def landmark_grouped_softmax(x, dim, is_mem, last_section_mask): - last_and_rest_mask = last_section_mask # | mask - - full_access_mask = is_mem | last_and_rest_mask - - max_mem_cnt = 16 - mem_group_idx = torch.cumsum(is_mem, dim=dim) - mem_bucket_id = max_mem_cnt - 1 - resp_mem_idx = torch.where( - last_and_rest_mask, - max_mem_cnt - 1, - torch.where(is_mem, mem_bucket_id, mem_group_idx), - ) - probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx) - - new_shape = list(x.shape) - new_shape[dim] = max_mem_cnt - group_prob = probs.new_zeros((*new_shape,)) - group_prob.scatter_( - dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs - ) - probs = probs.mul( - torch.where( - full_access_mask, - last_section_mask, - torch.gather(group_prob, dim, resp_mem_idx), - ) - ) - - return probs - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings - ) - - self.mem_freq = None - self.top_k = None - self.max_cache_size = None - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def set_mem_cache_args(self, mem_freq, top_k, max_cache_size): - self.mem_freq = mem_freq - self.top_k = top_k - self.max_cache_size = max_cache_size - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - is_mem: Optional[torch.Tensor] = None, - last_section_mask: Optional[torch.Tensor] = None, - offload_cache_to_cpu: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if len(past_key_value) > 2: - kv_seq_len += past_key_value[3].shape[2] * past_key_value[3].shape[3] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - key_states_before_pos = key_states - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - attn_prefix = None - if past_key_value is not None: - # reuse k, v, self_attention - if self.mem_freq is None: - cache_len = past_key_value[0].shape[2] - if self.max_cache_size is not None: - cache_len = min(cache_len, self.max_cache_size) - if is_mem is not None: - is_mem = torch.cat( - (is_mem.new_zeros((1, 1, q_len, cache_len)), is_mem), dim=-1 - ) - last_section_mask = torch.cat( - ( - last_section_mask.new_ones((1, 1, q_len, cache_len)), - last_section_mask, - ), - dim=-1, - ) - - past_key_states = torch.cat([past_key_value[0], key_states], dim=2) - past_value_states = torch.cat([past_key_value[1], value_states], dim=2) - key_states = past_key_states[:, :, -(q_len + cache_len) :] - value_states = past_value_states[:, :, -(q_len + cache_len) :] - expected_att_size = (bsz, self.num_heads, q_len, cache_len + q_len) - else: - orig_value_states = value_states - - incomplete_len = past_key_value[0].shape[2] % (self.mem_freq + 1) - full_len = past_key_value[0].shape[2] - incomplete_len - past_key_mem, past_key_incomplete = torch.split( - past_key_value[0], (full_len, incomplete_len), dim=2 - ) - past_value_mem, past_value_incomplete = torch.split( - past_key_value[1], (full_len, incomplete_len), dim=2 - ) - - if offload_cache_to_cpu: - past_key_value = ( - past_key_incomplete, - past_value_incomplete, - *past_key_value[2:], - ) - - if incomplete_len > 0: - assert q_len + incomplete_len <= (self.mem_freq + 1) - is_mem = torch.cat( - (is_mem.new_zeros((1, 1, q_len, incomplete_len)), is_mem), dim=-1 - ) - last_section_mask = torch.cat( - ( - last_section_mask.new_ones((1, 1, q_len, incomplete_len)), - last_section_mask, - ), - dim=-1, - ) - - if len(past_key_value) > 2: - full_len += past_key_value[3].shape[2] * past_key_value[3].shape[3] - past_key_incomplete_pos = torch.arange( - full_len, - full_len + incomplete_len, - dtype=torch.long, - device=position_ids.device, - ).unsqueeze(0) - _, past_key_incomplete = apply_rotary_pos_emb( - None, past_key_incomplete, cos, sin, past_key_incomplete_pos - ) - key_states = torch.cat((past_key_incomplete, key_states), dim=2) - value_states = torch.cat((past_value_incomplete, value_states), dim=2) - - past_key_mem = past_key_mem.view( - bsz, self.num_heads, -1, self.mem_freq + 1, self.head_dim - ) - past_value_mem = past_value_mem.view( - bsz, self.num_heads, -1, self.mem_freq + 1, self.head_dim - ) - - if len(past_key_value) > 2: - mem_key_nopos = torch.cat( - ( - past_key_value[2], - past_key_mem.select(dim=3, index=self.mem_freq), - ), - dim=2, - ) - past_key_mem_offload = past_key_value[3] - past_key_mem = torch.cat( - ( - past_key_mem_offload, - past_key_mem.to(past_key_mem_offload.device), - ), - dim=2, - ) - past_value_mem = torch.cat( - ( - past_key_value[4], - past_value_mem.to(past_key_mem_offload.device), - ), - dim=2, - ) - else: - mem_key_nopos = past_key_mem.select(dim=3, index=self.mem_freq) - - num_mems = past_key_mem.shape[2] - top_k = min(self.top_k, num_mems) - prefix_len = full_len - (top_k + 1) * (self.mem_freq + 1) - mem_indices = torch.cat( - ( - position_ids.new_zeros((max(0, num_mems - top_k),)), - torch.arange( - 1, - top_k + 1, - device=query_states.device, - dtype=position_ids.dtype, - ), - ), - dim=0, - ) - mem_pos = (mem_indices * (self.mem_freq + 1) + self.mem_freq).unsqueeze( - 0 - ).expand(bsz, -1) + prefix_len - _, mem_key = apply_rotary_pos_emb( - None, mem_key_nopos, cos, sin, mem_pos - ) - mem_attn_weights = torch.matmul( - query_states, mem_key.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if offload_cache_to_cpu: - aggregate = "max_over_tokens" - else: - aggregate = None - if aggregate == "max_over_tokens": - token_retrievers = 1 - head_retrievers = self.num_heads - mem_attn_weights = torch.nn.functional.softmax( - mem_attn_weights, dim=-1 - ) - mem_attn_weights = mem_attn_weights.amax(dim=2, keepdim=True) - elif aggregate is None: - token_retrievers = q_len - head_retrievers = self.num_heads - else: - raise NotImplementedError() - - mem_selected_idx = ( - mem_attn_weights.topk(dim=-1, k=top_k)[1] - .sort(dim=-1)[0] - .view(bsz, head_retrievers, token_retrievers, top_k) - ) - - selected_indices = torch.arange( - 0, - top_k * (self.mem_freq + 1), - device=query_states.device, - dtype=position_ids.dtype, - ) - selected_indices = torch.where( - mem_selected_idx >= num_mems - top_k, self.mem_freq + 1, 0 - ).unsqueeze(-1) + selected_indices.view( - 1, 1, 1, top_k, self.mem_freq + 1 - ) - selected_indices = ( - selected_indices.view( - bsz, head_retrievers, token_retrievers, -1 - ).expand(bsz, self.num_heads, q_len, -1) - + prefix_len - ) - - mem_selected_idx = mem_selected_idx.to(past_key_mem.device) - - mem_selected_idx = mem_selected_idx.view( - bsz, self.num_heads, token_retrievers, top_k, 1, 1 - ).expand( - bsz, - self.num_heads, - token_retrievers, - top_k, - self.mem_freq + 1, - self.head_dim, - ) - selected_keys = past_key_mem.unsqueeze(2).expand( - bsz, - self.num_heads, - token_retrievers, - -1, - self.mem_freq + 1, - self.head_dim, - ) - selected_keys = selected_keys.take_along_dim( - mem_selected_idx, dim=3 - ).to(query_states.device) - selected_values = ( - past_value_mem.unsqueeze(2) - .expand( - bsz, - self.num_heads, - token_retrievers, - -1, - self.mem_freq + 1, - self.head_dim, - ) - .take_along_dim(mem_selected_idx, dim=3) - .to(query_states.device) - ) - - selected_keys = selected_keys.view( - bsz, self.num_heads, token_retrievers, -1, self.head_dim - ).expand(bsz, self.num_heads, q_len, -1, self.head_dim) - selected_keys = apply_rotary_pos_emb( - None, selected_keys.unsqueeze(1), cos, sin, selected_indices - )[1].squeeze(1) - selected_values = selected_values.view( - bsz, self.num_heads, token_retrievers, -1, self.head_dim - ).expand(bsz, self.num_heads, q_len, -1, self.head_dim) - attn_prefix = torch.matmul( - query_states.unsqueeze(3), selected_keys.transpose(3, 4) - ).squeeze(3) / math.sqrt(self.head_dim) - is_mem_prefix = ( - torch.cat( - (is_mem.new_zeros((self.mem_freq,)), is_mem.new_ones((1,))) - ) - .unsqueeze(0) - .repeat((top_k, 1)) - ) - is_mem_prefix = is_mem_prefix.view(1, 1, 1, -1).expand(1, 1, q_len, -1) - is_mem = torch.cat((is_mem_prefix, is_mem), dim=-1) - last_section_mask = torch.cat( - ( - last_section_mask.new_zeros( - (1, 1, q_len, top_k * (self.mem_freq + 1)) - ), - last_section_mask, - ), - dim=-1, - ) - expected_att_size = (bsz, self.num_heads, q_len, q_len + incomplete_len) - - past_key_states = torch.cat( - [past_key_value[0], key_states_before_pos], dim=2 - ) - past_value_states = torch.cat( - [past_key_value[1], orig_value_states], dim=2 - ) - - if offload_cache_to_cpu: - past_key_value = ( - ( - past_key_states, - past_value_states, - mem_key_nopos, - past_key_mem.to("cpu"), - past_value_mem.to("cpu"), - *past_key_value[5:], - ) - if use_cache - else None - ) - else: - past_key_value = ( - (past_key_states, past_value_states) if use_cache else None - ) - - else: - if self.mem_freq is None: - past_key_states = key_states - else: - past_key_states = key_states_before_pos - past_value_states = value_states - expected_att_size = (bsz, self.num_heads, q_len, kv_seq_len) - past_key_value = (past_key_states, past_value_states) if use_cache else None - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - if attn_weights.size() != expected_att_size: - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask[..., -attn_weights.shape[-1] :] - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - if attn_prefix is not None: - attn_weights = torch.cat((attn_prefix, attn_weights), dim=-1) - # upcast attention to fp32 - if is_mem is None: - raise ValueError("Don't use this without landmarks") - - attn_weights = landmark_grouped_softmax( - attn_weights, - dim=-1, - is_mem=is_mem.expand(-1, self.num_heads, -1, -1), - last_section_mask=last_section_mask, - ).to(query_states.dtype) - - if attn_prefix is not None: - attn_prefix, attn_weights = torch.split( - attn_weights, - (attn_prefix.shape[-1], attn_weights.shape[-1] - attn_prefix.shape[-1]), - dim=-1, - ) - attn_output = torch.matmul(attn_weights, value_states) - if attn_prefix is not None: - attn_output += torch.matmul( - attn_prefix.unsqueeze(3), selected_values - ).squeeze(3) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - """ - Llama Decoder layer - """ - - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def set_mem_cache_args(self, mem_freq, top_k, max_cache_size): - self.self_attn.set_mem_cache_args(mem_freq, top_k, max_cache_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - is_mem: Optional[torch.Tensor] = None, - last_section_mask: Optional[torch.Tensor] = None, - offload_cache_to_cpu: bool = False, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - is_mem=is_mem, - last_section_mask=last_section_mask, - offload_cache_to_cpu=offload_cache_to_cpu, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.mem_id = None - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def set_mem_id(self, mem_id): - self.mem_id = mem_id - - def set_mem_cache_args(self, mem_freq, top_k, max_cache_size): - for layer in self.layers: - layer.set_mem_cache_args(mem_freq, top_k, max_cache_size) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - offload_cache_to_cpu: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - is_mem = None - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - if self.mem_id is not None: - with torch.no_grad(): - is_mem = input_ids == self.mem_id - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - if self.mem_id is not None: - raise NotImplementedError - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - if is_mem is not None: - pass - # raise NotImplementedError - past_key_values_length = past_key_values[0][0].shape[2] - if len(past_key_values[0]) > 2: - past_key_values_length += ( - past_key_values[0][3].shape[2] * past_key_values[0][3].shape[3] - ) - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - last_section_mask = None - if is_mem is not None: - is_mem = is_mem.unsqueeze(1).unsqueeze(2) - current_len = input_ids.shape[1] - mem_ids = torch.where( - attention_mask[..., -current_len:] < -1, - 0, - torch.cumsum(is_mem, -1) - is_mem.int(), - ) - last_section_mask = torch.amax(mem_ids, -1, keepdim=True) == mem_ids - attention_mask[..., -current_len:].masked_fill_( - last_section_mask & is_mem, - torch.tensor( - torch.finfo(inputs_embeds.dtype).min, device=inputs_embeds.device - ), - ) - last_section_mask.logical_and_(attention_mask[..., -current_len:] > -1) - is_mem = is_mem.logical_and(attention_mask[..., -current_len:] > -1) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - LOG.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - output_attentions, - None, - is_mem, - last_section_mask, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - is_mem=is_mem, - last_section_mask=last_section_mask, - offload_cache_to_cpu=offload_cache_to_cpu, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - """ - Llama model with a causal language modeling head. - """ - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.mem_id = None - self.mem_freq = None - self.top_k = None - self.max_seq_len = None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - offload_cache_to_cpu: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - window_len = self.max_seq_len or input_ids.shape[1] - last_logits = None - for _, idx in enumerate(range(0, input_ids.shape[1], window_len)): - if idx >= 1: - if output_attentions or output_hidden_states: - raise NotImplementedError - if not use_cache: - raise NotImplementedError - outputs = self.model( - input_ids=input_ids[:, idx : idx + window_len], - attention_mask=attention_mask[ - :, : idx + window_len + attention_mask.shape[1] - input_ids.shape[1] - ] - if attention_mask is not None - else None, - position_ids=position_ids[:, idx : idx + window_len] - if position_ids is not None - else None, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds[:, idx : idx + window_len] - if inputs_embeds is not None - else None, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - offload_cache_to_cpu=offload_cache_to_cpu, - ) - past_key_values = outputs.past_key_values - if last_logits is not None: - last_logits = torch.cat((last_logits, outputs[0]), dim=-2) - last_logits = outputs[0] - - hidden_states = last_logits - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def set_mem_id(self, mem_id): - self.mem_id = mem_id - self.model.set_mem_id(mem_id) - - def set_mem_cache_args(self, max_seq_len, mem_freq, top_k, max_cache_size): - self.mem_freq = mem_freq - self.top_k = top_k - self.max_seq_len = max_seq_len - if self.max_seq_len is not None: - assert self.max_seq_len % (self.mem_freq + 1) == 0 - self.model.set_mem_cache_args(mem_freq, top_k, max_cache_size) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - total_len = input_ids.shape[1] - if past_key_values: - prev_len = input_ids.shape[1] - 1 - else: - prev_len = 0 - - position_ids = kwargs.get("position_ids", None) - - if self.mem_freq is not None: - if position_ids is not None: - raise NotImplementedError - # T = input_ids.shape[1] - - prev_incomplete_len = prev_len % self.mem_freq - prev_complete_len = prev_len - prev_incomplete_len - incomplete_len = total_len % self.mem_freq - new_full_len = total_len - prev_complete_len - incomplete_len - - prev_input, input_ids_with_mem, input_ids_without_mem = torch.split( - input_ids, (prev_complete_len, new_full_len, incomplete_len), dim=-1 - ) - - bsz, _ = input_ids.size() - input_ids_with_mem = input_ids_with_mem.view(bsz, -1, self.mem_freq) - input_ids_with_mem = torch.cat( - ( - input_ids_with_mem, - input_ids_with_mem.new_full( - (bsz, input_ids_with_mem.shape[1], 1), self.mem_id - ), - ), - dim=-1, - ).view(bsz, -1) - input_ids = torch.cat( - (prev_input, input_ids_with_mem, input_ids_without_mem), dim=-1 - ) - if attention_mask is not None: - attention_mask_with_mem, attention_mask_without_mem = torch.split( - attention_mask, - (prev_complete_len + new_full_len, incomplete_len), - dim=-1, - ) - attention_mask_with_mem = attention_mask_with_mem.view( - bsz, -1, self.mem_freq - ) - attention_mask_with_mem = torch.cat( - ( - attention_mask_with_mem, - attention_mask_with_mem.new_ones( - (bsz, attention_mask_with_mem.shape[1], 1) - ), - ), - dim=-1, - ).view(bsz, -1) - attention_mask = torch.cat( - (attention_mask_with_mem, attention_mask_without_mem), dim=-1 - ) - - input_ids = input_ids[:, prev_len:] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[:, -input_ids.shape[1] :].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if ( - inputs_embeds is not None - and past_key_values is None - and self.mem_freq is None - ): - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "offload_cache_to_cpu": kwargs.get("offload_cache_to_cpu"), - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past - - -def add_mem_tokens(example, mem_freq, mem_id): - ids = example["input_ids"] - ret = [] - prev_idx = 0 - for t_idx in range(mem_freq, len(ids), mem_freq): - ret.extend(ids[prev_idx:t_idx]) - ret.append(mem_id) - prev_idx = t_idx - ret.extend(ids[prev_idx:]) - # drop attention_mask - return {"input_ids": ret} - - -def patch_llama_with_landmark_attn(): - import transformers - - transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM - transformers.models.llama.modeling_llama.LlamaModel = LlamaModel - transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer - transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb - - -def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer): - mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) - model.set_mem_id(mem_id) - - -def get_mem_id(tokenizer: LlamaTokenizer): - return tokenizer.convert_tokens_to_ids(MEM_TOKEN) diff --git a/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py b/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py deleted file mode 100644 index 4cbbd4f479..0000000000 --- a/src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py +++ /dev/null @@ -1,94 +0,0 @@ -# pylint: skip-file -""" -Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py -""" -import torch -import transformers -import transformers.models.llama.modeling_llama -from einops import rearrange - - -class XposRotaryEmbedding(torch.nn.Module): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scale_base=2048, - use_xpos=True, - ): - super().__init__() - self.max_seq_len_cached = max_position_embeddings - self.scale_base = scale_base - - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq) - freqs = torch.einsum("i , j -> i j", t, inv_freq) - freqs = torch.cat((freqs, freqs), dim=-1) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.register_buffer("freqs_cached", freqs, persistent=False) - - if not use_xpos: - self.register_buffer("scale", None) - self.register_buffer("scale_cached", torch.ones(1)) - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - power = (t - (self.max_seq_len_cached // 2)) / self.scale_base - scale_cached = scale ** rearrange(power, "n -> n 1") - scale_cached = torch.cat((scale_cached, scale_cached), dim=-1) - - self.register_buffer("scale", scale, persistent=False) - self.register_buffer("scale_cached", scale_cached, persistent=False) - - def forward( - self, - x, - seq_len, - ): - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device).type_as( - self.inv_freq - ) - freqs = torch.einsum("i , j -> i j", t, self.inv_freq) - freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype) - - self.register_buffer("freqs_cached", freqs) - - if self.scale is None: - self.register_buffer( - "scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype) - ) - - return self.freqs_cached.to(dtype=x.dtype), self.scale_cached - - power = (t - (seq_len // 2)) / self.scale_base - scale = self.scale ** rearrange(power, "n -> n 1") - scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype) - self.register_buffer("scale_cached", scale) - - return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype) - - -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None): - freqs = freqs[position_ids, :] - if scale.shape[-1] != 1: - scale = scale[position_ids, :] - - q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale) - k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale) - - return q_embed, k_embed - - -def replace_llama_rope_with_xpos_rope(): - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding - transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8cb9e8426a..872d530abd 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -247,17 +247,6 @@ def load_model( LOG.info("patching with sdp attention") hijack_llama_sdp_attention() - elif cfg.is_llama_derived_model and cfg.landmark_attention: - from axolotl.monkeypatch.llama_landmark_attn import ( - MEM_TOKEN, - patch_llama_with_landmark_attn, - ) - - LOG.info("patching with landmark attention") - patch_llama_with_landmark_attn() - - # Note: This might overwrite previous additional_special_tokens - tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( @@ -279,14 +268,6 @@ def load_model( LOG.info("patching with flash attention") replace_mixtral_attn_with_multipack_flash_attn() - if cfg.is_llama_derived_model and cfg.xpos_rope: - from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( - replace_llama_rope_with_xpos_rope, - ) - - LOG.info("patching with xpos rope") - replace_llama_rope_with_xpos_rope() - if ( cfg.is_llama_derived_model and (cfg.max_packed_sequence_len or cfg.sample_packing) From 76357dc5dae39ad51b3083b828b19eafc8e6b7bd Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Thu, 28 Dec 2023 18:00:02 -0800 Subject: [PATCH 21/41] Update README.md (#1012) --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a4122772d1..5f487642ea 100644 --- a/README.md +++ b/README.md @@ -990,9 +990,12 @@ tokens: # these are delimiters When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary. -### Inference +### Inference Playground -Pass the appropriate flag to the train command: +Axolotl allows you to load your model in an interactive terminal playground for quick experimentation. +The config file is the same config file used for training. + +Pass the appropriate flag to the inference command, depending upon what kind of model was trained: - Pretrained LORA: ```bash From dec66d7c53a2de6cf74911faf9c1ad1f7f0fff14 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Thu, 28 Dec 2023 18:00:16 -0800 Subject: [PATCH 22/41] [Docs] Nit: Remind people to auth to wandb if they are going to use it (#1013) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 5f487642ea..2e0a07855c 100644 --- a/README.md +++ b/README.md @@ -672,6 +672,7 @@ relora_warmup_steps: # Number of per-restart warmup steps relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings # wandb configuration if you're using it +# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_project: # Your wandb project name wandb_entity: # A wandb Team name if using a Team @@ -964,6 +965,8 @@ fsdp_config: ##### Weights & Biases Logging +Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. + - wandb options ```yaml wandb_mode: From f6ecf14dd42fe265cacf3f67ca7dad7474b5b642 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 29 Dec 2023 18:15:30 +0900 Subject: [PATCH 23/41] feat: remove need to add load_in* during merge (#1017) --- README.md | 4 ++-- src/axolotl/cli/merge_lora.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2e0a07855c..bbed3e10df 100644 --- a/README.md +++ b/README.md @@ -996,7 +996,7 @@ When you include these tokens in your axolotl config, axolotl adds these tokens ### Inference Playground Axolotl allows you to load your model in an interactive terminal playground for quick experimentation. -The config file is the same config file used for training. +The config file is the same config file used for training. Pass the appropriate flag to the inference command, depending upon what kind of model was trained: @@ -1027,7 +1027,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi Add below flag to train command above ```bash -python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False +python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" ``` If you run out of CUDA memory, you can try to merge in system RAM with diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 0caee4c28b..4c810d5722 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -18,7 +18,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs): return_remaining_strings=True ) parsed_cli_args.merge_lora = True - parsed_cfg = load_cfg(config, merge_lora=True, **kwargs) + + parsed_cfg = load_cfg( + config, + merge_lora=True, + load_in_8bit=False, + load_in_4bit=False, + flash_attention=False, + **kwargs + ) do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) From 41353d2ea04db3478d2f6f9069b7d0adb1f30ae8 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 29 Dec 2023 18:16:26 +0900 Subject: [PATCH 24/41] feat: expose bnb kwargs (#1018) * feat: expose bnb kwargs * chore: added examples and link per suggestion * Uncomment defaults per suggestion for readability Co-authored-by: Hamel Husain --------- Co-authored-by: Hamel Husain --- README.md | 8 ++++++++ src/axolotl/utils/models.py | 19 +++++++++++++------ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index bbed3e10df..d15c4c001e 100644 --- a/README.md +++ b/README.md @@ -520,6 +520,14 @@ model_config: type: # linear | dynamic factor: # float +# optional overrides to the bnb 4bit quantization configuration +# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig +bnb_config_kwargs: + # These are default values + llm_int8_has_fp16_weight: false + bnb_4bit_quant_type: nf4 + bnb_4bit_use_double_quant: true + # Whether you are training a 4-bit GPTQ quantized model gptq: true diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 872d530abd..c2b3a758c7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -301,13 +301,20 @@ def load_model( **model_config.quantization_config ) if cfg.adapter == "qlora" and cfg.load_in_4bit: + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + "bnb_4bit_compute_dtype": cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + "bnb_4bit_quant_type": "nf4", + } + + if cfg.bnb_config_kwargs: + bnb_config.update(cfg.bnb_config_kwargs) + model_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=cfg.torch_dtype, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", + **bnb_config, ) # sample packing uses custom FA2 patch if cfg.flash_attention: From ba043a361e233d53fa82348ccaf6e3d7f13ee5c6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 29 Dec 2023 12:23:29 -0600 Subject: [PATCH 25/41] add ultrachat prompt strategies (#996) --- src/axolotl/prompt_strategies/sharegpt.py | 31 +++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index fbb44ccfae..c026889682 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -39,6 +39,23 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): return strategy +def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None + ) + strategy = UltrachatShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation=conversation, + ), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "strict" in ds_cfg: + strategy.strict = ds_cfg["strict"] + return strategy + + def load_role(tokenizer, cfg): return SimpleRoleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2(), @@ -109,3 +126,17 @@ def get_conversation_thread(self, prompt): {"from": role_map[t["role"]], "value": t["text"]} for t in conversations ] return turns + + +class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): + """ + sharegpt strategy that remaps ultrachat data to sharegpt format + """ + + def get_conversation_thread(self, prompt): + conversations = prompt["messages"] + role_map = {"user": "human", "assistant": "gpt"} + turns = [ + {"from": role_map[t["role"]], "value": t["content"]} for t in conversations + ] + return turns From 4f4d638b84251226fe6f4b5160fa329c88a19ba5 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Fri, 29 Dec 2023 10:52:12 -0800 Subject: [PATCH 26/41] [WandB] Push axolotl config to top level wandb files (#1014) --- src/axolotl/utils/callbacks.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 8599c0df0f..122cd92ede 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -4,6 +4,8 @@ import logging import os +from shutil import copyfile +from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Dict, List import evaluate @@ -561,10 +563,15 @@ def on_train_begin( ): if is_main_process(): try: - artifact = wandb.Artifact(name="axolotl-config", type="config") - artifact.add_file(local_path=self.axolotl_config_path) - wandb.run.log_artifact(artifact) - LOG.info("Axolotl config has been saved to WandB as an artifact.") + # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + wandb.save(temp_file.name) + LOG.info( + "The Axolotl config has been saved to the WandB run under files." + ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control From f8ae59b0a89a8db5c7c956a91979971cefd8350c Mon Sep 17 00:00:00 2001 From: mhenrichsen Date: Fri, 29 Dec 2023 22:44:23 +0100 Subject: [PATCH 27/41] Adds chat templates (#1022) --- README.md | 3 +++ src/axolotl/utils/chat_templates.py | 29 +++++++++++++++++++++++++++++ src/axolotl/utils/models.py | 7 +++++++ 3 files changed, 39 insertions(+) create mode 100644 src/axolotl/utils/chat_templates.py diff --git a/README.md b/README.md index d15c4c001e..98b8a78239 100644 --- a/README.md +++ b/README.md @@ -589,6 +589,9 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: +# Saves the desired chat template to the tokenizer_config.json for easier inferencing +# Currently supports chatml and inst (mistral/mixtral) +chat_template: chatml # Axolotl attempts to save the dataset as an arrow after packing the data together so # subsequent training attempts load faster, relative path dataset_prepared_path: data/last_run_prepared diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py new file mode 100644 index 0000000000..459da44007 --- /dev/null +++ b/src/axolotl/utils/chat_templates.py @@ -0,0 +1,29 @@ +""" +This module provides functionality for selecting chat templates based on user choices. +These templates are used for formatting messages in a conversation. +""" + + +def chat_templates(user_choice: str): + """ + Finds the correct chat_template for the tokenizer_config. + + Args: + user_choice (str): The user's choice of template. + + Returns: + str: The chosen template string. + + Raises: + ValueError: If the user_choice is not found in the templates. + """ + + templates = { + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + } + + if user_choice in templates: + return templates[user_choice] + + raise ValueError(f"Template '{user_choice}' not found.") diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c2b3a758c7..df6907c3d5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -26,6 +26,7 @@ from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.chat_templates import chat_templates from axolotl.utils.dict import DictDefault LOG = logging.getLogger("axolotl") @@ -186,6 +187,12 @@ def load_tokenizer(cfg): LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + if cfg.chat_template: + tokenizer.chat_template = chat_templates(cfg.chat_template) + else: + LOG.info( + "No Chat template selected. Consider adding a chat template for easier inference." + ) return tokenizer From 3678a6c41d051ca6376d013c11c948e55b4c8b4f Mon Sep 17 00:00:00 2001 From: Tazik Shahjahan <35576188+taziksh@users.noreply.github.com> Date: Fri, 29 Dec 2023 14:15:53 -0800 Subject: [PATCH 28/41] Fix: bf16 support for inference (#981) * Fix: bf16 torch dtype * simplify casting to device and dtype --------- Co-authored-by: Wing Lian --- src/axolotl/cli/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index e6537ad052..85f6b358ac 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -103,7 +103,7 @@ def do_inference( importlib.import_module("axolotl.prompters"), prompter ) - model = model.to(cfg.device) + model = model.to(cfg.device, dtype=cfg.torch_dtype) while True: print("=" * 80) @@ -168,7 +168,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) - model = model.to(cfg.device) + model = model.to(cfg.device, dtype=cfg.torch_dtype) def generate(instruction): if not instruction: From 4d2e842e46bf8bd6dd0fda4d2667a7e7d80b4cd4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 1 Jan 2024 22:17:27 -0500 Subject: [PATCH 29/41] use recommended setting for use_reentrant w gradient checkpointing (#1021) * use recommended setting for use_reentrant w gradient checkpointing * add doc for gradient_checkpointing_kwargs --- README.md | 3 +++ src/axolotl/core/trainer_builder.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/README.md b/README.md index 98b8a78239..4dd80339a4 100644 --- a/README.md +++ b/README.md @@ -741,6 +741,9 @@ group_by_length: false # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing gradient_checkpointing: false +# additional kwargs to pass to the trainer for gradient checkpointing +# gradient_checkpointing_kwargs: +# use_reentrant: false # Stop training after this many evaluation losses have increased in a row # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fed26de464..4ca2877d19 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -566,6 +566,14 @@ def build(self, total_num_steps): training_arguments_kwargs[ "gradient_checkpointing" ] = self.cfg.gradient_checkpointing + if self.cfg.gradient_checkpointing_kwargs: + training_arguments_kwargs[ + "gradient_checkpointing_kwargs" + ] = self.cfg.gradient_checkpointing_kwargs + else: + training_arguments_kwargs["gradient_checkpointing_kwargs"] = { + "use_reentrant": False + } if self.cfg.fsdp: training_arguments_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp_config: From c75f91674566594c1ef68a922368eeeaac3bdf07 Mon Sep 17 00:00:00 2001 From: Tim Dolan <40906019+tdolan21@users.noreply.github.com> Date: Tue, 2 Jan 2024 20:00:37 -0500 Subject: [PATCH 30/41] added tiny llama examples for lora and qlora (#1027) * added tiny llama examples for lora and qlora * corrected yml files and removed tiny-llama.yml from llama-2 example --- examples/tiny-llama/README.md | 17 +++++ .../tiny-llama.yml => tiny-llama/lora.yml} | 9 +-- examples/tiny-llama/qlora.yml | 67 +++++++++++++++++++ 3 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 examples/tiny-llama/README.md rename examples/{llama-2/tiny-llama.yml => tiny-llama/lora.yml} (87%) create mode 100644 examples/tiny-llama/qlora.yml diff --git a/examples/tiny-llama/README.md b/examples/tiny-llama/README.md new file mode 100644 index 0000000000..467c06ec87 --- /dev/null +++ b/examples/tiny-llama/README.md @@ -0,0 +1,17 @@ +# Overview + +This is a simple example of how to finetune TinyLlama1.1B using either lora or qlora: + +LoRa: + +``` +accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml +``` + +qLoRa: + +``` +accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml +``` + +Both take about 10 minutes to complete on a 4090. diff --git a/examples/llama-2/tiny-llama.yml b/examples/tiny-llama/lora.yml similarity index 87% rename from examples/llama-2/tiny-llama.yml rename to examples/tiny-llama/lora.yml index c72db4e5b2..d72ce8eb44 100644 --- a/examples/llama-2/tiny-llama.yml +++ b/examples/tiny-llama/lora.yml @@ -1,5 +1,4 @@ -base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T - +base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer is_llama_derived_model: true @@ -17,6 +16,7 @@ output_dir: ./lora-out sequence_len: 4096 sample_packing: true +pad_to_sequence_len: true adapter: lora lora_model_dir: @@ -55,7 +55,6 @@ flash_attention: true warmup_steps: 10 evals_per_epoch: 4 -eval_table_size: saves_per_epoch: 1 debug: deepspeed: @@ -63,6 +62,4 @@ weight_decay: 0.0 fsdp: fsdp_config: special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" + diff --git a/examples/tiny-llama/qlora.yml b/examples/tiny-llama/qlora.yml new file mode 100644 index 0000000000..02af851adb --- /dev/null +++ b/examples/tiny-llama/qlora.yml @@ -0,0 +1,67 @@ +base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + From b31038aae90c296f108dde777037d223a1e5bfff Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 3 Jan 2024 11:56:19 +0900 Subject: [PATCH 31/41] chore(readme): update instruction to set config to load from cache (#1030) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4dd80339a4..53dd46aacf 100644 --- a/README.md +++ b/README.md @@ -935,8 +935,9 @@ accelerate launch -m axolotl.cli.train your_config.yml You can optionally pre-tokenize dataset with the following before finetuning. This is recommended for large datasets. -- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. -- Use `--debug` to see preprocessed examples. +- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset. +- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. +- (Optional): Use `--debug` to see preprocessed examples. ```bash python -m axolotl.cli.preprocess your_config.yml From a3e878332811ab1710d1e94217e2f3d914f0c48c Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Tue, 2 Jan 2024 21:35:20 -0800 Subject: [PATCH 32/41] [Docs] delete unused cfg value `lora_out_dir` (#1029) * Update README.md * Update README.md * Update README.md Co-authored-by: NanoCode012 --------- Co-authored-by: NanoCode012 --- README.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 53dd46aacf..172dd558e2 100644 --- a/README.md +++ b/README.md @@ -643,7 +643,8 @@ max_memory: # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model adapter: lora # If you already have a lora model trained that you want to load, put that here. -# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. +# This means after training, if you want to test the model, you should set this to the value of `output_dir`. +# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`. lora_model_dir: # LoRA hyperparameters @@ -670,10 +671,6 @@ lora_modules_to_save: # - embed_tokens # - lm_head -# Once you complete training, the model will be saved to the following directory. -# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory. -# Make sure `lora_model_dir` points to this directory if you want to use the trained model. -lora_out_dir: lora_fan_in_fan_out: false # ReLoRA configuration From 8ba27f3bdec4bae372cdf76a9da7aadc6895b288 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 4 Jan 2024 00:23:44 +0900 Subject: [PATCH 33/41] fix: lint (#1037) --- examples/tiny-llama/lora.yml | 1 - examples/tiny-llama/qlora.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/tiny-llama/lora.yml b/examples/tiny-llama/lora.yml index d72ce8eb44..53d50178a8 100644 --- a/examples/tiny-llama/lora.yml +++ b/examples/tiny-llama/lora.yml @@ -62,4 +62,3 @@ weight_decay: 0.0 fsdp: fsdp_config: special_tokens: - diff --git a/examples/tiny-llama/qlora.yml b/examples/tiny-llama/qlora.yml index 02af851adb..53791985ef 100644 --- a/examples/tiny-llama/qlora.yml +++ b/examples/tiny-llama/qlora.yml @@ -64,4 +64,3 @@ weight_decay: 0.0 fsdp: fsdp_config: special_tokens: - From 74532ddc458e6ca2c4e43771243e1b6b6fbb7813 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 4 Jan 2024 01:19:52 +0900 Subject: [PATCH 34/41] chore(config): clean up old log for Qwen (#1034) --- src/axolotl/utils/config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index d9e56b95a6..9bade45728 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -422,11 +422,6 @@ def validate_config(cfg): if cfg.warmup_steps and cfg.warmup_ratio: raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") - if cfg.is_qwen_derived_model and cfg.gradient_checkpointing: - LOG.warning( - "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch." - ) - if cfg.wandb_run_id and not cfg.wandb_name: cfg.wandb_name = cfg.wandb_run_id From bcc78d8fa393ae07f4df364d1104d63cf778c9e1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 3 Jan 2024 15:11:04 -0500 Subject: [PATCH 35/41] bump transformers and update attention class map name (#1023) * bump transformers and update attention class map name * also run the tests in docker * add mixtral e2e smoke test * fix base name for docker image in test * mixtral lora doesn't seem to work, at least check qlora * add testcase for mixtral w sample packing * check monkeypatch for flash attn multipack * also run the e2e tests in docker * use all gpus to run tests in docker ci * use privileged mode too for docker w gpus * rename the docker e2e actions for gh ci * set privileged mode for docker and update mixtral model self attn check * use fp16/bf16 for mixtral w fa2 * skip e2e tests on docker w gpus for now * tests to validate mistral and mixtral patches * fix rel import --- .github/workflows/tests-docker.yml | 62 +++++++++ requirements.txt | 2 +- src/axolotl/monkeypatch/mixtral/__init__.py | 2 +- .../monkeypatch/mixtral/modeling_mixtral.py | 8 +- src/axolotl/utils/models.py | 3 + tests/e2e/test_mixtral.py | 109 ++++++++++++++++ tests/e2e/test_mixtral_samplepack.py | 123 ++++++++++++++++++ tests/e2e/test_model_patches.py | 99 ++++++++++++++ 8 files changed, 404 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/tests-docker.yml create mode 100644 tests/e2e/test_mixtral.py create mode 100644 tests/e2e/test_mixtral_samplepack.py create mode 100644 tests/e2e/test_model_patches.py diff --git a/.github/workflows/tests-docker.yml b/.github/workflows/tests-docker.yml new file mode 100644 index 0000000000..ff30d68ea2 --- /dev/null +++ b/.github/workflows/tests-docker.yml @@ -0,0 +1,62 @@ +name: e2e-docker-tests + +on: + pull_request: + paths: + - '**.py' + - 'requirements.txt' + workflow_dispatch: + +jobs: + build-axolotl: + if: github.repository_owner == 'OpenAccess-AI-Collective' + # this job needs to be run on self-hosted GPU runners... + strategy: + fail-fast: false + matrix: + include: + - cuda: 118 + cuda_version: 11.8.0 + python_version: "3.10" + pytorch: 2.0.1 + axolotl_extras: + is_latest: true + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.10" + pytorch: 2.1.1 + axolotl_extras: + runs-on: [self-hosted, gpu, docker] + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Docker metadata + id: metadata + uses: docker/metadata-action@v5 + with: + images: winglian/axolotl + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/ + - name: Build and export to Docker + uses: docker/build-push-action@v5 + with: + context: . + load: true + build-args: | + BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} + CUDA=${{ matrix.cuda }} + PYTORCH_VERSION=${{ matrix.pytorch }} + file: ./docker/Dockerfile + tags: | + ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} + labels: ${{ steps.metadata.outputs.labels }} + - name: Unit Tests + run: | + docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ diff --git a/requirements.txt b/requirements.txt index c1c1cbc132..f4df0dd671 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ auto-gptq==0.5.1 packaging peft==0.6.0 -transformers==4.36.2 +transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 tokenizers==0.15.0 bitsandbytes>=0.41.1 accelerate==0.24.1 diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py index 4188146892..74fa00f649 100644 --- a/src/axolotl/monkeypatch/mixtral/__init__.py +++ b/src/axolotl/monkeypatch/mixtral/__init__.py @@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn(): transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = ( mixtral_model_forward ) - transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[ + transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[ "flash_attention_2" ] = MixtralMultipackFlashAttention2 diff --git a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py index 34f35015f9..db892530d6 100644 --- a/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +++ b/src/axolotl/monkeypatch/mixtral/modeling_mixtral.py @@ -261,7 +261,11 @@ def mixtral_model_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if ( + attention_mask is not None + and self._attn_implementation == "flash_attention_2" + and use_cache + ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -270,7 +274,7 @@ def mixtral_model_forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = ( attention_mask diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index df6907c3d5..fb2420108a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -332,15 +332,18 @@ def load_model( or cfg.is_mistral_derived_model or model_config.model_type == "mixtral" ): + model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: if model_config.model_type == "mixtral": + model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" ) else: + model_kwargs["attn_implementation"] = "eager" model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py new file mode 100644 index 0000000000..896cc74d0f --- /dev/null +++ b/tests/e2e/test_mixtral.py @@ -0,0 +1,109 @@ +""" +E2E tests for mixtral +""" + +import logging +import os +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMixtral(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/tests/e2e/test_mixtral_samplepack.py b/tests/e2e/test_mixtral_samplepack.py new file mode 100644 index 0000000000..b43702a512 --- /dev/null +++ b/tests/e2e/test_mixtral_samplepack.py @@ -0,0 +1,123 @@ +""" +E2E tests for mixtral +""" + +import logging +import os +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMixtral(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 2048, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "sample_packing": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "sample_packing": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert ( + "axolotl.monkeypatch.mixtral.modeling_mixtral" + in model.model.layers[0].self_attn.__class__.__module__ + ) + assert ( + "MixtralMultipackFlashAttention2" + in model.model.layers[0].self_attn.__class__.__name__ + ) + assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/tests/e2e/test_model_patches.py b/tests/e2e/test_model_patches.py new file mode 100644 index 0000000000..eb11244644 --- /dev/null +++ b/tests/e2e/test_model_patches.py @@ -0,0 +1,99 @@ +""" +E2E smoke tests to check that the monkeypatches are in place for certain configurations +""" + +import unittest + +from axolotl.common.cli import TrainerCliArgs +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + +from .utils import with_temp_dir + + +class TestModelPatches(unittest.TestCase): + """ + TestCases for the multipack monkey patches + """ + + @with_temp_dir + def test_mixtral_multipack(self, temp_dir): + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + tokenizer = load_tokenizer(cfg) + model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + + assert ( + "axolotl.monkeypatch.mixtral.modeling_mixtral" + in model.model.layers[0].self_attn.__class__.__module__ + ) + assert ( + "MixtralMultipackFlashAttention2" + in model.model.layers[0].self_attn.__class__.__name__ + ) + + @with_temp_dir + def test_mistral_multipack(self, temp_dir): + cfg = DictDefault( + { + "base_model": "openaccess-ai-collective/tiny-mistral", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + tokenizer = load_tokenizer(cfg) + model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + + assert ( + "axolotl.monkeypatch.mistral_attn_hijack_flash" + in model.model.layers[0].self_attn.forward.__module__ + ) From 59b2d302c8780ed83e6a0201b741574ee51a1a5e Mon Sep 17 00:00:00 2001 From: xaviviro Date: Thu, 4 Jan 2024 13:03:04 +0100 Subject: [PATCH 36/41] Added chatglm3 conversation type for training models like TinyLLama (#1036) * Added chatgml3 conversation type for training models like TinyLLama * Added chatgml3 conversation type for training models like TinyLLama with lint * Added chatgml3 conversation type for training models like TinyLLama with lint --- src/axolotl/monkeypatch/fastchat_conversation_turns.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index 068261da36..aafdabe547 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -147,6 +147,15 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role + "\n", "" return + if self.sep_style == SeparatorStyle.CHATGLM3: + if self.system_message: + yield "", system_prompt + for role, message in self.messages: + if message: + yield role + "\n", " " + message + else: + yield role + return if self.sep_style == SeparatorStyle.CHATINTERN: # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 seps = [self.sep, self.sep2] From f243c2186d1575d25fd2b62ef7f3e3d22fd03db6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 4 Jan 2024 18:21:25 -0500 Subject: [PATCH 37/41] RL/DPO (#935) * ipo-dpo trainer * fix missing abstract method * chatml template, grad checkpointing kwargs support * fix steps calc for RL and add dataloader kwargs * wip to fix dpo and start ppo * more fixes * refactor to generalize map fn * fix dataset loop and handle argilla pref dataset * set training args * load reference model on seperate gpu if more than one device * no auto upload to hub for dpo, don't add lora adapters to ref model for dpo * fixes for rl training * support for ipo from yaml * set dpo training args from the config, add tests * chore: lint * set sequence_len for model in test * add RLHF docs --- docs/rlhf.md | 35 +++++++++ requirements.txt | 2 + src/axolotl/cli/__init__.py | 90 ++++++++++++++++++++++ src/axolotl/cli/train.py | 6 +- src/axolotl/core/trainer_builder.py | 103 ++++++++++++++++++++++++++ src/axolotl/core/trainers/__init__.py | 0 src/axolotl/core/trainers/trl.py | 66 +++++++++++++++++ src/axolotl/train.py | 8 +- src/axolotl/utils/models.py | 16 +++- src/axolotl/utils/trainer.py | 9 ++- tests/core/test_trainer_builder.py | 59 +++++++++++++++ 11 files changed, 388 insertions(+), 6 deletions(-) create mode 100644 docs/rlhf.md create mode 100644 src/axolotl/core/trainers/__init__.py create mode 100644 src/axolotl/core/trainers/trl.py create mode 100644 tests/core/test_trainer_builder.py diff --git a/docs/rlhf.md b/docs/rlhf.md new file mode 100644 index 0000000000..371a40dbf7 --- /dev/null +++ b/docs/rlhf.md @@ -0,0 +1,35 @@ +# RLHF (Beta) + +### Overview + +Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human +feedback. Various methods include, but not limited to: + +- Proximal Policy Optimization (PPO) (not yet supported in axolotl) +- Direct Preference Optimization (DPO) +- Identity Preference Optimization (IPO) + + +### RLHF using Axolotl + +[!IMPORTANT] +This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality. + +The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML + +#### DPO +```yaml +rl: true +datasets: + - path: Intel/orca_dpo_pairs + split: train + type: intel_apply_chatml + - path: argilla/ultrafeedback-binarized-preferences + split: train + type: argilla_apply_chatml +``` + +#### IPO +```yaml +rl: ipo +``` diff --git a/requirements.txt b/requirements.txt index f4df0dd671..14f6633f7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,5 @@ tensorboard s3fs gcsfs # adlfs + +trl @ git+https://github.com/huggingface/trl.git@main diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 85f6b358ac..0477ebebfb 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -2,6 +2,7 @@ import importlib import logging +import math import os import random import sys @@ -16,6 +17,7 @@ # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art +from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer @@ -325,6 +327,94 @@ def load_datasets( ) +def load_rl_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint: disable=unused-argument +) -> TrainDatasetMeta: + train_datasets: List[Any] = [] + for i, ds_cfg in enumerate(cfg.datasets): + train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"])) + # eval_dataset = load_dataset( + # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"] + # ) + eval_dataset = None + + def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" + sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + return sample + + def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + for i, data_set in enumerate(train_datasets): + _type = cfg.datasets[i]["type"] + ds_type_fn = locals()[_type] + train_datasets[i] = data_set.map(ds_type_fn) + train_dataset = concatenate_datasets(train_datasets) + + # eval_dataset = eval_dataset.map(intel_apply_chatml) + + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + def check_accelerate_default_config(): if Path(config_args.default_yaml_config_file).exists(): LOG.warning( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 81307b6b92..2248784dff 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -12,6 +12,7 @@ check_user_token, load_cfg, load_datasets, + load_rl_datasets, print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs @@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if parsed_cfg.rl: + dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + else: + dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4ca2877d19..1ca36eb418 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -20,6 +20,7 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_utils import seed_worker +from trl import DPOTrainer from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils.callbacks import ( @@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC): _train_dataset = None _eval_dataset = None + _model_ref = None def __init__(self, cfg, model, tokenizer): self.cfg = cfg self.model = model self.tokenizer = tokenizer + @property + def model_ref(self): + return self._model_ref + + @model_ref.setter + def model_ref(self, model): + self._model_ref = model + @property def train_dataset(self): return self._train_dataset @@ -827,3 +837,96 @@ def build_collator(self, **kwargs): return_tensors="pt", **kwargs, ) + + +class HFDPOTrainerBuilder(TrainerBuilderBase): + """ + Trainer factory class for DPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build_training_arguments(self, total_num_steps): + training_args_kwargs = {} + for arg in [ + "adam_beta1", + "adam_beta2", + "adam_epsilon", + "dataloader_num_workers", + "dataloader_pin_memory", + ]: + if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: + training_args_kwargs[arg] = getattr(self.cfg, arg) + training_args = TrainingArguments( + per_device_train_batch_size=self.cfg.micro_batch_size, + max_steps=total_num_steps, + remove_unused_columns=False, + gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, + learning_rate=self.cfg.learning_rate, + evaluation_strategy="no", + # eval_steps=self.cfg.eval_steps, + save_strategy="steps", + save_steps=self.cfg.save_steps, + output_dir=self.cfg.output_dir, + warmup_steps=self.cfg.warmup_steps, + bf16=True, + gradient_checkpointing=self.cfg.gradient_checkpointing, + gradient_checkpointing_kwargs={"use_reentrant": False}, + logging_first_step=True, + logging_steps=1, + optim=self.cfg.optimizer, + save_total_limit=self.cfg.save_total_limit or 5, + **training_args_kwargs, + ) + + return training_args + + def build(self, total_num_steps): + training_args = self.build_training_arguments(total_num_steps) + dpo_trainer_kwargs = {} + if self.cfg.rl == "ipo": + dpo_trainer_kwargs["loss_type"] = "ipo" + if self.cfg.dpo_label_smoothing: + dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + + dpo_trainer = DPOTrainer( + self.model, + self.model_ref, + args=training_args, + beta=self.cfg.dpo_beta or 0.1, + train_dataset=self.train_dataset, + # eval_dataset=self.eval_dataset, + eval_dataset=None, + tokenizer=self.tokenizer, + max_length=self.cfg.sequence_len, + max_target_length=None, + max_prompt_length=self.cfg.sequence_len, + generate_during_eval=True, + **dpo_trainer_kwargs, + ) + + return dpo_trainer + + +class HFPPOTrainerBuilder(TrainerBuilderBase): + """ + HF Factory class for PPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build(self, total_num_steps): + # build PPOConfig + pass diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py new file mode 100644 index 0000000000..24c0b04123 --- /dev/null +++ b/src/axolotl/core/trainers/trl.py @@ -0,0 +1,66 @@ +""" +module for TRL PPO training +""" +import torch +from tqdm import tqdm +from trl import PPOTrainer + + +class TRLPPOTrainer(PPOTrainer): + """ + wrapper for ppo trainer to handle customizations + """ + + def train( + self, + reward_pipe, + resume_from_checkpoint=None, # pylint: disable=unused-argument + ): + generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.tokenizer.eos_token_id, + "max_new_tokens": 32, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + } + + for epoch, batch in tqdm( # pylint: disable=unused-variable + enumerate(self.dataloader) + ): + query_tensors = batch["input_ids"] + + # generate model response + response_tensors, ref_response_tensors = self.generate( + query_tensors, + return_prompt=False, + generate_ref_response=True, + **generation_kwargs + ) + batch["response"] = self.tokenizer.batch_decode(response_tensors) + batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] + ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) + ref_rewards = [ + torch.tensor(output[1]["score"]) for output in ref_pipe_outputs + ] + batch["ref_rewards"] = ref_rewards + + # Run PPO step + stats = self.step(query_tensors, response_tensors, rewards) + self.log_stats( + stats, + batch, + rewards, + columns_to_log=["query", "response", "ref_response", "ref_rewards"], + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 4e5241e4c8..e0da112528 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -61,6 +61,12 @@ def train( msg += " and peft_config..." LOG.debug(msg) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model_ref = None + if cfg.rl: + # load the model again for model_ref/baseline + model_ref, _ = load_model( + cfg, tokenizer, inference=cli_args.inference, reference_model=True + ) safe_serialization = cfg.save_safetensors is True @@ -83,7 +89,7 @@ def train( freeze_parameters_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( - cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps + cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps ) if hasattr(model, "config"): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fb2420108a..b30ffcad8c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -200,6 +200,7 @@ def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, inference: bool = False, + reference_model: bool = False, ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -290,6 +291,15 @@ def load_model( model_kwargs["device_map"] = cfg.device_map model_kwargs["max_memory"] = cfg.max_memory model_kwargs["torch_dtype"] = cfg.torch_dtype + # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) if is_deepspeed_zero3_enabled(): del model_kwargs["device_map"] @@ -560,9 +570,11 @@ def load_model( if hasattr(module, "weight"): module.to(cfg.torch_dtype) - model, lora_config = load_adapter(model, cfg, cfg.adapter) + lora_config = None + if not reference_model or cfg.lora_model_dir: + model, lora_config = load_adapter(model, cfg, cfg.adapter) - if cfg.ddp and not load_in_8bit: + if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): model.to(f"cuda:{cfg.local_rank}") if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f046dd7be8..d975bb9a2d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -12,7 +12,7 @@ from datasets import set_caching_enabled from torch.utils.data import DataLoader, RandomSampler -from axolotl.core.trainer_builder import HFCausalTrainerBuilder +from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.samplers import MultipackBatchSampler @@ -280,7 +280,12 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) + if cfg.rl: + trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.model_ref = model[1] + else: + trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py new file mode 100644 index 0000000000..e8987ef452 --- /dev/null +++ b/tests/core/test_trainer_builder.py @@ -0,0 +1,59 @@ +""" +unit tests for axolotl.core.trainer_builder +""" +import pytest + +from axolotl.core.trainer_builder import HFDPOTrainerBuilder +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + + +@pytest.fixture(name="cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "LlamaTokenizer", + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00005, + "save_steps": 100, + "output_dir": "./model-out", + "warmup_steps": 10, + "gradient_checkpointing": False, + "optimizer": "adamw_torch", + "sequence_len": 2048, + "rl": True, + "adam_beta1": 0.998, + "adam_beta2": 0.9, + "adam_epsilon": 0.00001, + "dataloader_num_workers": 1, + "dataloader_pin_memory": True, + } + ) + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(cfg): + return load_tokenizer(cfg) + + +@pytest.fixture(name="model") +def fixture_model(cfg, tokenizer): + return load_model(cfg, tokenizer) + + +class TestHFDPOTrainerBuilder: + """ + TestCase class for DPO trainer builder + """ + + def test_build_training_arguments(self, cfg, model, tokenizer): + builder = HFDPOTrainerBuilder(cfg, model, tokenizer) + training_arguments = builder.build_training_arguments(100) + assert training_arguments.adam_beta1 == 0.998 + assert training_arguments.adam_beta2 == 0.9 + assert training_arguments.adam_epsilon == 0.00001 + assert training_arguments.dataloader_num_workers == 1 + assert training_arguments.dataloader_pin_memory is True From 98247539b9fdbb1ee6f38f2a010bb8e0454e0a3d Mon Sep 17 00:00:00 2001 From: "jinwonkim93@github.com" Date: Fri, 15 Dec 2023 03:14:40 +0000 Subject: [PATCH 38/41] [Feat] streaming multipack --- src/axolotl/utils/data.py | 70 +++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 10 ++++ tests/test_packed_pretraining.py | 87 ++++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 tests/test_packed_pretraining.py diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 5c41d16fe4..165e2c8fd0 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -2,9 +2,11 @@ import functools import hashlib import logging +from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple, Union +import numpy as np import torch from datasets import ( Dataset, @@ -14,6 +16,7 @@ load_from_disk, ) from huggingface_hub import hf_hub_download +from torch.utils.data import RandomSampler from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH @@ -41,9 +44,11 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.samplers.multipack import MultipackBatchSampler from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, + process_pretraining_datasets_for_packing, ) LOG = logging.getLogger("axolotl") @@ -819,3 +824,68 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): remove_columns=dataset.features.keys(), ) return dataset + + +def encode_packed_pretraining( + tokenizer: PreTrainedTokenizerBase, + examples: List[str], + max_seq_length: int = 8192, + sample_packing_efficiency: int = 1, +) -> Dict[str, List]: + # pylint: disable=duplicate-code + # tokenize all the examples + # rows get split with stride (overlap) + res = tokenizer( + examples, + truncation=True, + max_length=max_seq_length - 1, + add_special_tokens=True, + return_overflowing_tokens=True, + stride=256, + ) + + input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]] + attention_mask = [seq + [1] for seq in res["attention_mask"]] + + tokenized_examples = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + train_dataset = Dataset.from_dict(tokenized_examples) + train_dataset = process_pretraining_datasets_for_packing( + train_dataset, max_seq_length + ) + + sampler = MultipackBatchSampler( + RandomSampler(train_dataset), + batch_size=1, + drop_last=True, + batch_max_len=max_seq_length, + lengths=( + train_dataset.data.column("position_ids") + .to_pandas() + .apply(lambda x: x[-1] + 1) + .values + ), + packing_efficiency_estimate=sample_packing_efficiency, + ) + + chunked_data = defaultdict(list) + + for data in sampler: + features = train_dataset[data] + + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [(1) * np.array(item) for item in features[feature]] + chunked_data[feature].append(np.concatenate(arrays)) + else: + arrays = [np.array(item) for item in features[feature]] + chunked_data[feature].append(np.concatenate(arrays)) + + chunked_data["labels"] = chunked_data["input_ids"].copy() + + return chunked_data diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d975bb9a2d..3139f56004 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): return train_dataset, eval_dataset +def process_pretraining_datasets_for_packing(train_dataset, sequence_len): + drop_long = partial(drop_long_seq, sequence_len=sequence_len) + + train_dataset = train_dataset.filter(drop_long) + train_dataset = train_dataset.map( + add_position_ids, + ) + return train_dataset + + def calculate_total_num_steps(cfg, train_dataset, update=True): if not cfg.total_num_tokens: total_num_tokens = np.sum( diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py new file mode 100644 index 0000000000..9dee6ed7f0 --- /dev/null +++ b/tests/test_packed_pretraining.py @@ -0,0 +1,87 @@ +"""Module for testing streaming dataset sequence packing""" +import math +import unittest +from functools import partial + +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from axolotl.utils.collators import DataCollatorForSeq2Seq +from axolotl.utils.data import encode_packed_pretraining + + +class TestPacking(unittest.TestCase): + """ + Test class for packing streaming dataset sequences + """ + + def setUp(self) -> None: + # pylint: disable=duplicate-code + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "[PAD]", + } + ) + self.max_seq_length = 8192 + self.batch_size = 6 + self.sample_packing_efficiency = 1 + self.data_collator_kwargs = { + "padding": True, + "pad_to_multiple_of": 64 * math.ceil(self.max_seq_length / 64), + } + + def test_packing_stream_dataset(self): + # pylint: disable=duplicate-code + dataset = load_dataset( + "c4", + "en", + streaming=True, + )["train"] + + encode = partial( + encode_packed_pretraining, + self.tokenizer, + max_seq_length=self.max_seq_length, + sample_packing_efficiency=self.sample_packing_efficiency, + ) + + dataset = dataset.map( + encode, + batched=True, + input_columns="text", + remove_columns=dataset.features.keys(), + ) + + data_collator_fn = DataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + **self.data_collator_kwargs, + ) + + trainer_loader = DataLoader( + dataset, + batch_size=self.batch_size, + collate_fn=data_collator_fn, + drop_last=True, + ) + idx = 0 + for data in trainer_loader: + if idx > 10: + break + assert data["input_ids"].shape == (self.batch_size, self.max_seq_length) + assert data["position_ids"].shape == (self.batch_size, self.max_seq_length) + assert data["labels"].shape == (self.batch_size, self.max_seq_length) + assert data["attention_mask"].shape == ( + self.batch_size, + self.max_seq_length, + ) + idx += 1 + + +if __name__ == "__main__": + unittest.main() From d05fe8ca97a09c7fe1082acad860fe3763dca2a3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 16:04:37 +0000 Subject: [PATCH 39/41] WIP make continued pretraining work w multipack --- examples/tiny-llama/pretrain.yml | 56 +++++++++++++++++++++++++++++ src/axolotl/cli/train.py | 2 ++ src/axolotl/core/trainer_builder.py | 6 ++-- src/axolotl/utils/collators.py | 26 ++++++++++++++ src/axolotl/utils/data.py | 37 ++++++++++--------- 5 files changed, 107 insertions(+), 20 deletions(-) create mode 100644 examples/tiny-llama/pretrain.yml diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml new file mode 100644 index 0000000000..b0319a95f2 --- /dev/null +++ b/examples/tiny-llama/pretrain.yml @@ -0,0 +1,56 @@ +base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +max_steps: 200 +pretraining_dataset: c4 +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./model-out + +sequence_len: 2048 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2248784dff..54242dd583 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -5,6 +5,7 @@ from pathlib import Path import fire +import torch import transformers from axolotl.cli import ( @@ -19,6 +20,7 @@ from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") +# torch.set_printoptions(threshold=10000) def do_cli(config: Path = Path("examples/"), **kwargs): diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 1ca36eb418..ab6cf9d3f0 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -157,7 +157,7 @@ def create_scheduler( return self.lr_scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing: + if self.args.sample_packing and False: return MultipackBatchSampler( RandomSampler(self.train_dataset), self.args.train_batch_size, @@ -193,7 +193,7 @@ def _get_eval_sampler( return super()._get_eval_sampler(eval_dataset) def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing: + if self.args.sample_packing and False: train_dataset = self.train_dataset train_dataset = train_dataset.remove_columns(["length"]) data_collator = self.data_collator @@ -807,7 +807,7 @@ def build(self, total_num_steps): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, - data_collator=self.build_collator(**data_collator_kwargs), + # data_collator=self.build_collator(**data_collator_kwargs), bench_data_collator=transformers.DataCollatorForSeq2Seq( self.tokenizer, return_tensors="pt", diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index 0f0eb5a95a..b4c4fa4dff 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -178,3 +178,29 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: "input_ids": input_ids, "labels": labels, } + +@dataclass +class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + """ + Collator for multipack specific to the using the BatchSampler + """ + + def __call__(self, features, return_tensors=None): + chunked_data = {} + for feature in features.keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [ + (1) * np.array(item) + for item in features[feature] + ] + chunked_data[feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item) for item in features[feature] + ] + chunked_data[feature] = np.concatenate(arrays) + features = [chunked_data] + return super().__call__(features, return_tensors=return_tensors) + diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 165e2c8fd0..199d4a64ab 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union, Optional import numpy as np import torch @@ -42,6 +42,7 @@ SummarizeTLDRPrompter, UnsupportedPrompter, ) +from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.samplers.multipack import MultipackBatchSampler @@ -72,6 +73,7 @@ def prepare_dataset(cfg, tokenizer): train_dataset = load_pretraining_dataset( cfg.pretraining_dataset, tokenizer, + cfg, max_tokens=cfg.sequence_len, seed=cfg.seed or 42, ) @@ -811,9 +813,15 @@ def encode_pretraining( return ret -def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): - encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = load_dataset(path, streaming=True, split="train") +def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): + if cfg.sample_packing: + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens) + encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size) + cfg.micro_batch_size = 1 + else: + encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + + dataset = load_dataset(path, streaming=True, split="train", name="en") dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.map( encode, @@ -828,9 +836,10 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): def encode_packed_pretraining( tokenizer: PreTrainedTokenizerBase, + collate_fn, examples: List[str], - max_seq_length: int = 8192, - sample_packing_efficiency: int = 1, + max_seq_length: int = 2048, + batch_size: int = 4, ) -> Dict[str, List]: # pylint: disable=duplicate-code # tokenize all the examples @@ -859,33 +868,27 @@ def encode_packed_pretraining( sampler = MultipackBatchSampler( RandomSampler(train_dataset), - batch_size=1, + batch_size=batch_size, drop_last=True, - batch_max_len=max_seq_length, + batch_max_len=batch_size * max_seq_length, lengths=( train_dataset.data.column("position_ids") .to_pandas() .apply(lambda x: x[-1] + 1) .values ), - packing_efficiency_estimate=sample_packing_efficiency, ) chunked_data = defaultdict(list) for data in sampler: features = train_dataset[data] + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) for feature in features.keys(): if feature == "length": continue - if feature == "attention_mask": - arrays = [(1) * np.array(item) for item in features[feature]] - chunked_data[feature].append(np.concatenate(arrays)) - else: - arrays = [np.array(item) for item in features[feature]] - chunked_data[feature].append(np.concatenate(arrays)) - - chunked_data["labels"] = chunked_data["input_ids"].copy() + chunked_data[feature].append(collated_features[feature].squeeze(0)) return chunked_data From 0ac242f4b9e67c5edd0cf8c08d082586978d904d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 11:15:43 -0500 Subject: [PATCH 40/41] fix up hadrcoding, lint --- examples/tiny-llama/pretrain.yml | 6 ++++-- src/axolotl/cli/train.py | 2 -- src/axolotl/core/trainer_builder.py | 11 +++++++++-- src/axolotl/utils/collators.py | 11 +++-------- src/axolotl/utils/data.py | 29 ++++++++++++++++++++++------- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml index b0319a95f2..dfd1bfca29 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -9,7 +9,9 @@ load_in_4bit: false strict: false max_steps: 200 -pretraining_dataset: c4 +pretraining_dataset: + path: c4 + name: en dataset_prepared_path: val_set_size: 0.0 output_dir: ./model-out @@ -45,7 +47,7 @@ xformers_attention: flash_attention: true warmup_steps: 10 -evals_per_epoch: +evals_per_epoch: eval_table_size: saves_per_epoch: 1 debug: diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 54242dd583..2248784dff 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -5,7 +5,6 @@ from pathlib import Path import fire -import torch import transformers from axolotl.cli import ( @@ -20,7 +19,6 @@ from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") -# torch.set_printoptions(threshold=10000) def do_cli(config: Path = Path("examples/"), **kwargs): diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ab6cf9d3f0..b75766d043 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "Use quadratic warmup for cosine scheduling."}, ) + pretraining: bool = field( + default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) sample_packing: bool = field( default=False, metadata={"help": "Use sample packing for efficient training."}, @@ -157,7 +163,7 @@ def create_scheduler( return self.lr_scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and False: + if self.args.sample_packing and not self.args.pretraining: return MultipackBatchSampler( RandomSampler(self.train_dataset), self.args.train_batch_size, @@ -193,7 +199,7 @@ def _get_eval_sampler( return super()._get_eval_sampler(eval_dataset) def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing and False: + if self.args.sample_packing and not self.args.pretraining: train_dataset = self.train_dataset train_dataset = train_dataset.remove_columns(["length"]) data_collator = self.data_collator @@ -767,6 +773,7 @@ def build(self, total_num_steps): training_arguments_kwargs ) training_arguments_kwargs["model_type"] = self.cfg.model_config_type + training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[ diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index b4c4fa4dff..b9c1c3b3c1 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -179,6 +179,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: "labels": labels, } + @dataclass class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ @@ -191,16 +192,10 @@ def __call__(self, features, return_tensors=None): if feature == "length": continue if feature == "attention_mask": - arrays = [ - (1) * np.array(item) - for item in features[feature] - ] + arrays = [(1) * np.array(item) for item in features[feature]] chunked_data[feature] = np.concatenate(arrays) else: - arrays = [ - np.array(item) for item in features[feature] - ] + arrays = [np.array(item) for item in features[feature]] chunked_data[feature] = np.concatenate(arrays) features = [chunked_data] return super().__call__(features, return_tensors=return_tensors) - diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 199d4a64ab..0de5d4ee32 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,9 +4,8 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Tuple, Union -import numpy as np import torch from datasets import ( Dataset, @@ -42,7 +41,7 @@ SummarizeTLDRPrompter, UnsupportedPrompter, ) -from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.samplers.multipack import MultipackBatchSampler @@ -70,10 +69,17 @@ def prepare_dataset(cfg, tokenizer): tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) else: + path = cfg.pretraining_dataset + name = None + if isinstance(dict, cfg.pretraining_dataset): + path = cfg.pretraining_dataset.path + name = cfg.pretraining_dataset.name + train_dataset = load_pretraining_dataset( - cfg.pretraining_dataset, + path, tokenizer, cfg, + name=name, max_tokens=cfg.sequence_len, seed=cfg.seed or 42, ) @@ -815,13 +821,22 @@ def encode_pretraining( def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): if cfg.sample_packing: - collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens) - encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size) + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens + ) + encode = functools.partial( + encode_packed_pretraining, + tokenizer, + collate_fn, + max_seq_length=max_tokens, + batch_size=cfg.micro_batch_size, + ) + # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 else: encode = functools.partial(encode_pretraining, tokenizer, max_tokens) - dataset = load_dataset(path, streaming=True, split="train", name="en") + dataset = load_dataset(path, streaming=True, split="train", name=name) dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.map( encode, From a5eb52e9f0ad86241fe5a5735d41fbd265d01a8f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jan 2024 11:22:34 -0500 Subject: [PATCH 41/41] fix dict check --- src/axolotl/utils/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 0de5d4ee32..40a3306021 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -71,9 +71,9 @@ def prepare_dataset(cfg, tokenizer): else: path = cfg.pretraining_dataset name = None - if isinstance(dict, cfg.pretraining_dataset): - path = cfg.pretraining_dataset.path - name = cfg.pretraining_dataset.name + if isinstance(cfg.pretraining_dataset, dict): + path = cfg.pretraining_dataset["path"] + name = cfg.pretraining_dataset["name"] train_dataset = load_pretraining_dataset( path,