Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ia3 peft support #601

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open

add ia3 peft support #601

wants to merge 10 commits into from

Conversation

winglian
Copy link
Collaborator

No description provided.

@winglian
Copy link
Collaborator Author

we can support 4-bit IA3 once huggingface/peft#864 is merged.

Copy link
Collaborator

@NanoCode012 NanoCode012 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not used IA3 before, but here's my comments from looking at the linked PR.

src/axolotl/utils/models.py Outdated Show resolved Hide resolved
src/axolotl/utils/models.py Outdated Show resolved Hide resolved
src/axolotl/utils/models.py Outdated Show resolved Hide resolved
src/axolotl/utils/models.py Outdated Show resolved Hide resolved
if (
(cfg.adapter == "lora" and cfg.load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
or (cfg.adapter == "ia3" and cfg.load_in_8bit)
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second point, is ia3 load_in_8bit or 4bit? The linked PR seems to be 4bit addition but also support 8bit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's 8 bit only for now. I added some checks to warn in the config validation.

examples/llama-2/ia3.yml Outdated Show resolved Hide resolved
Copy link
Collaborator

@NanoCode012 NanoCode012 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will need to run this myself when I have time to verify since there's a lot of changes.

@@ -450,11 +452,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM",
)

if cfg.lora_model_dir:
if cfg.peft_model_dir or cfg.lora_model_dir:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're updating to peft_model_dir, we could add a deprecation warning to validate config to reduce need for checking both like this line.

For backward compatibility, we can assign cfg.peft_model_dir = cfg.lora_model_dir if it's not None.

README.md Outdated
@@ -519,6 +519,9 @@ lora_modules_to_save:
# - lm_head
lora_out_dir:
lora_fan_in_fan_out: false
ia3_target_modules: # target modules for IA3, for llama, k, v, and down projections
ia3_feedforward_modules: # ffn modules for IA3, for llama down projection
ia3_fan_in_fan_out:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The target modules and fan in fan out feels a bit redundant since we have two similar names..

@His-Wardship
Copy link

His-Wardship commented Sep 20, 2023

Hi - I wrote the PR for 4-bit IA3 - I adjusted my own installation of Axolotl to support IA3 (I didn't submit a PR as it was a hack based on rewriting existing LoRA support, which naturally broke it for LoRA purposes, and as I have literally no training or experience in coding, I wasn't confident in adding a new functionality without breaking everything else) and found IA3 ran properly for training with no other major changes required. Comparing my hack to this PR, the changes here seem near identical - fortunately I have found IA3 and LoRA are mostly interchangeable from a code perspective, there weren't any misleading adjustments I had to make to get it working. I did not test loading, inference or merging weights in Axolotl using IA3, as I did these tasks using my own scripts or adaptions of existing scripts.

The only points I would raise are that:

  1. One of the principal benefits of IA3 is that it supports a vastly higher learning rate than LoRA (this is indicated in the original IA3 paper). This allows for high quality training to be performed with far fewer epochs. I have fine-tuned several models on complex technical documentation using LR ~0.004. This may be worth flagging to users in either a sample .yaml or in the readme itself, as missing this results in forgoing much of the improvement IA3 offers over LoRA.
  2. With respect to target modules and feedforward modules, the PEFT library already contains default settings for the majority of existing model architectures (peft.utils.other.TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING and ... TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING) which it assigns if peft.tuners.ia3._prepare_adapter_config does not receive a value. In practice, I have found that targeting all linear modules has a relatively minor effect on training speed and on the size of the adapter file (in that it doubles it from ~3MB to 6MB). It may be preferable to advise the user to allow PEFT to assign at least the feedforward module if the user is not confident in what they are selecting.
  3. At least initially, the existing llama_attn_hijack_flash script did not seem to be properly recasting the model dtype to bf16/fp16 as required by Flash Attention when using IA3. I expect this was probably my fault, though I could not identify why it wasn't working (I inserted debug lines and the function was in fact being called). This led to tensor size mismatch or just Flash Attention throwing an error and terminating. Again, as a hack, I just shoved a forced recast into axolotl.utils.models.py. I note that the llama_attn_hijack_flash script has been re-organised since I last cloned the Axolotl repo, and so expect that this problem (if it wasn't my doing) may have been solved. Still - if you get tensor mismatches on initial tests, this would be where I would recommend looking first.
  4. The current implementation of IA3 in PEFT does not support merging weights in 4-bit. I have gotten around this by loading the model in bf16 and merging the weights to that and then requantizing it. This is probably inefficient and, more importantly, for a ~34B model requires almost all my memory (24GB VRAM + 128GB CPU RAM), which I imagine is not feasible on most home PCs. I hope to put in a PR relating to merging 4-bit IA3 weights, but again, as I have no prior knowledge or experience of AI (or coding at all), I'm limited by how fast I can read the documentation.

@Napuh
Copy link
Contributor

Napuh commented Oct 13, 2023

Just a reminder, huggingface/peft#864 has been merged.

@official-elinas
Copy link

I was trying PEFT's version of IA3 back in early August and it would not work, regardless of what I tried. I'm curious to see what this will produce and will test it as soon as I can.


# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save:
peft_modules_to_save:
# - embed_tokens
# - lm_head

# Once you complete training, the model will be saved to the following directory.
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you may have missed this variable

@@ -151,6 +153,13 @@ def flashattn_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

if query_states.dtype == torch.float32:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a comment to explain this casting?

LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")

if not cfg.load_in_8bit and cfg.adapter == "ia3":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can consolidate the checks here into one.

cfg.adapter in ["lora", "ia3"]


if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
if "lm_head" in peft_module_names: # needed for 16-bit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add a log if done so and user explicitly set this.

@creatorrr
Copy link

@winglian any updates on this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants