Support peft 0.11.0 (#953)
tastelikefeet committed May 17, 2024
1 parent 9074a2f commit 70abbe7
Showing 11 changed files with 247 additions and 34 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ Additionally, we are expanding capabilities for other modalities. Currently, we
SWIFT has rich documentation for users; please check [here](https://github.com/modelscope/swift/tree/main/docs/source_en/LLM).

## 🎉 News
- 🔥2024.05.17: Support peft=0.11.0. Meanwhile, three new tuners are supported: `BOFT`, `Vera` and `Pissa`. Use `--sft_type boft/vera` to enable BOFT or Vera, and use `--init_lora_weights pissa` together with `--sft_type lora` to enable Pissa.
- 2024.05.16: Supports Llava-Next (Stronger) series models. For best practice, you can refer to [here](https://github.com/modelscope/swift/tree/main/docs/source_en/Multi-Modal/llava-best-practice.md).
- 🔥2024.05.13: Support Yi-1.5 series models. Use `--model_type yi-1_5-9b-chat` to begin!
- 2024.05.11: Support for qlora training and quantized inference using [hqq](https://github.com/mobiusml/hqq) and [eetq](https://github.com/NetEase-FuXi/EETQ). For more information, see the [LLM Quantization Documentation](https://github.com/modelscope/swift/tree/main/docs/source_en/LLM/LLM-quantization.md).
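The 2024.05.17 news entry above describes the new tuners in terms of CLI flags. As a quick illustration, here is a hedged sketch of the same switches driven from swift's Python entry point; `sft_main`/`SftArguments` are assumed to be importable from `swift.llm`, and the model/dataset names are placeholders rather than part of this commit.

```python
# Hypothetical sketch only: the model_type and dataset values are placeholders.
from swift.llm import SftArguments, sft_main

# BOFT (use sft_type='vera' instead to try Vera)
sft_main(SftArguments(model_type='qwen1half-7b-chat', dataset=['alpaca-zh'], sft_type='boft'))

# PiSSA: ordinary LoRA fine-tuning with PiSSA-style initialization of the LoRA weights
sft_main(SftArguments(model_type='qwen1half-7b-chat', dataset=['alpaca-zh'],
                      sft_type='lora', init_lora_weights='pissa'))
```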
1 change: 1 addition & 0 deletions README_CN.md
@@ -42,6 +42,7 @@ SWIFT supports the training, inference, … of nearly **200 LLMs and MLLMs** (multimodal large models)
SWIFT has a rich documentation system; if you have any usage questions, please check [here](https://github.com/modelscope/swift/tree/main/docs/source/LLM).

## 🎉 News
- 🔥2024.05.17: Support peft=0.11.0. Meanwhile, three new tuners are supported: `BOFT`, `Vera` and `Pissa`. Use `--sft_type boft/vera` to enable BOFT or Vera, and use `--init_lora_weights pissa` together with `--sft_type lora` to enable Pissa.
- 2024.05.16: Support the Llava-Next (Stronger) series of models. For best practices, see [here](https://github.com/modelscope/swift/tree/main/docs/source/Multi-Modal/llava最佳实践.md).
- 🔥2024.05.13: Support the Yi-1.5 series of models. Use `--model_type yi-1_5-9b-chat` and similar to get started!
- 2024.05.11: Support qlora training and quantized inference with [hqq](https://github.com/mobiusml/hqq) and [eetq](https://github.com/NetEase-FuXi/EETQ). For more details, see the [LLM Quantization Documentation](https://github.com/modelscope/swift/tree/main/docs/source/LLM/LLM量化文档.md).
18 changes: 18 additions & 0 deletions docs/source/LLM/命令行参数.md
@@ -63,6 +63,7 @@
- `--lora_rank`: Default is `8`. Only takes effect when `sft_type` is 'lora'.
- `--lora_alpha`: Default is `32`. Only takes effect when `sft_type` is 'lora'.
- `--lora_dropout_p`: Default is `0.05`. Only takes effect when `sft_type` is 'lora'.
- `--init_lora_weights`: Method for initializing the LoRA weights. Can be `true`, `false`, `gaussian`, `pissa`, or `pissa_niter_[number of iters]`. Default is `true`.
- `--lora_bias_trainable`: Default is `'none'`. Options: 'none', 'all'. If you want all biases to be trainable, set it to `'all'`.
- `--lora_modules_to_save`: Default is `[]`. If you want to train embedding, lm_head, or layer_norm, you can set this parameter, e.g. `--lora_modules_to_save EMBEDDING LN lm_head`. If `'EMBEDDING'` is passed, the Embedding layer is added to `lora_modules_to_save`. If `'LN'` is passed, `RMSNorm` and `LayerNorm` are added to `lora_modules_to_save`.
- `--lora_dtype`: Default is `'AUTO'`. Specifies the dtype of the lora modules; if `AUTO`, the dtype of the original module is followed. Options: 'fp16', 'bf16', 'fp32', 'AUTO'.
@@ -136,6 +137,23 @@

- `--sequence_parallel_size`: Default is `1`. A value greater than 1 splits a sequence across multiple GPUs to save GPU memory; the value needs to evenly divide the DDP world size.

### BOFT Parameters

- `--boft_block_size`: BOFT block size. Default is 4.
- `--boft_block_num`: Number of BOFT blocks. Cannot be used together with `boft_block_size`.
- `--boft_target_modules`: BOFT target modules. Default is `['DEFAULT']`. If `boft_target_modules` is set to `'DEFAULT'` or `'AUTO'`, `boft_target_modules` is looked up in `MODEL_MAPPING` based on `model_type` (defaults to qkv). If set to `'ALL'`, all Linear layers (excluding the head) are used as BOFT modules.
- `--boft_dropout`: Dropout value for BOFT. Default is 0.0.
- `--boft_modules_to_save`: Additional modules to be trained and saved. Default is `None`.

### Vera Parameters

- `--vera_rank`: Size of Vera Attention. Default is 256.
- `--vera_projection_prng_key`: PRNG key used to initialize the Vera projection matrices. Default is 0.
- `--vera_target_modules`: Vera target modules. Default is `['DEFAULT']`. If `vera_target_modules` is set to `'DEFAULT'` or `'AUTO'`, `vera_target_modules` is looked up in `MODEL_MAPPING` based on `model_type` (defaults to qkv). If set to `'ALL'`, all Linear layers (excluding the head) are used as Vera modules. All Vera target modules need to share the same shape.
- `--vera_dropout`: Dropout value for Vera. Default is `0.0`.
- `--vera_d_initial`: Initial value of Vera's d matrix. Default is `0.1`.
- `--vera_modules_to_save`: Additional modules to be trained and saved. Default is `None`.

### LoRA+ Fine-tuning Parameters

- `--lora_lr_ratio`: Default is `None`. Recommended value: `10~16`. Specify this parameter when using LoRA to enable LoRA+.
18 changes: 18 additions & 0 deletions docs/source_en/LLM/Command-line-parameters.md
@@ -63,6 +63,7 @@
- `--lora_rank`: Default is `8`. Only takes effect when `sft_type` is 'lora'.
- `--lora_alpha`: Default is `32`. Only takes effect when `sft_type` is 'lora'.
- `--lora_dropout_p`: Default is `0.05`, only takes effect when `sft_type` is 'lora'.
- `--init_lora_weights`: Method to initialize LoRA weights, can be specified as `true`, `false`, `gaussian`, `pissa`, or `pissa_niter_[number of iters]`. Default value `true`.
- `--lora_bias_trainable`: Default is `'none'`, options: 'none', 'all'. Set to `'all'` to make all biases trainable.
- `--lora_modules_to_save`: Default is `[]`. If you want to train embedding, lm_head, or layer_norm, you can set this parameter, e.g. `--lora_modules_to_save EMBEDDING LN lm_head`. If passed `'EMBEDDING'`, Embedding layer will be added to `lora_modules_to_save`. If passed `'LN'`, `RMSNorm` and `LayerNorm` will be added to `lora_modules_to_save`.
- `--lora_dtype`: Default is `'AUTO'`, specifies dtype for lora modules. If `AUTO`, follow dtype of original module. Options: 'fp16', 'bf16', 'fp32', 'AUTO'.
@@ -135,6 +136,23 @@

- `--sequence_parallel_size`: Default value `1`. A value greater than 1 splits each sequence across multiple GPUs to reduce memory usage; the value should evenly divide the GPU count.

### BOFT Parameters

- `--boft_block_size`: BOFT block size, default value is 4.
- `--boft_block_num`: Number of BOFT blocks, cannot be used simultaneously with `boft_block_size`.
- `--boft_target_modules`: BOFT target modules. Default is `['DEFAULT']`. If `boft_target_modules` is set to `'DEFAULT'` or `'AUTO'`, it will look up `boft_target_modules` in the `MODEL_MAPPING` based on `model_type` (default specified as qkv). If set to `'ALL'`, all Linear layers (excluding the head) will be designated as BOFT modules.
- `--boft_dropout`: Dropout value for BOFT, default is 0.0.
- `--boft_modules_to_save`: Additional modules to be trained and saved, default is `None`.

### Vera Parameters

- `--vera_rank`: Size of Vera Attention, default value is 256.
- `--vera_projection_prng_key`: PRNG key used to initialize the Vera projection matrices, default is 0.
- `--vera_target_modules`: Vera target modules. Default is `['DEFAULT']`. If `vera_target_modules` is set to `'DEFAULT'` or `'AUTO'`, it will look up `vera_target_modules` in the `MODEL_MAPPING` based on `model_type` (default specified as qkv). If set to `'ALL'`, all Linear layers (excluding the head) will be designated as Vera modules. All Vera target modules need to share the same shape.
- `--vera_dropout`: Dropout value for Vera, default is 0.0.
- `--vera_d_initial`: Initial value for Vera's d matrix, default is 0.1.
- `--vera_modules_to_save`: Additional modules to be trained and saved, default is `None`.

### LoRA+ Fine-tuning Parameters

- `--lora_lr_ratio`: Default `None`, recommended value `10~16`. Specify this parameter when using LoRA to enable LoRA+.
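The BOFT and Vera parameter lists above are thin wrappers over the corresponding peft 0.11 configs (see `swift/llm/tuner.py` below). As a rough, non-authoritative sketch of what those flags resolve to on the peft side — the checkpoint and target-module names here are placeholders:

```python
# Hedged sketch: peft-side equivalents of the BOFT/Vera flags documented above.
from transformers import AutoModelForCausalLM
from peft import BOFTConfig, VeraConfig, get_peft_model

# --sft_type boft: set either boft_block_size or boft_block_num, not both.
boft_config = BOFTConfig(
    boft_block_size=4,                              # --boft_block_size
    boft_dropout=0.0,                               # --boft_dropout
    target_modules=['q_proj', 'k_proj', 'v_proj'],  # resolved --boft_target_modules (placeholder names)
)

# --sft_type vera: all target Linear layers must share the same shape.
vera_config = VeraConfig(
    r=256,                                          # --vera_rank
    projection_prng_key=0,                          # --vera_projection_prng_key
    vera_dropout=0.0,                               # --vera_dropout
    d_initial=0.1,                                  # --vera_d_initial
    target_modules=['q_proj', 'k_proj', 'v_proj'],  # resolved --vera_target_modules (placeholder names)
)

model = AutoModelForCausalLM.from_pretrained('some-causal-lm')  # placeholder checkpoint
model = get_peft_model(model, boft_config)  # or vera_config, applied to a freshly loaded model
```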
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -8,7 +8,7 @@ nltk
numpy
optimum>=1.17.0
pandas
peft>=0.9.0,<0.11.0
peft>=0.11.0,<0.12.0
requests
rouge
safetensors
69 changes: 67 additions & 2 deletions swift/llm/tuner.py
@@ -9,8 +9,8 @@

from swift.torchacc_utils import consolidate_checkpoint
from swift.trainers import TrainerCallback
from swift.tuners import (AdaLoraConfig, AdapterConfig, IA3Config, LongLoRAModelType, LoraConfig, LoRAConfig,
NEFTuneConfig, Swift)
from swift.tuners import (AdaLoraConfig, AdapterConfig, BOFTConfig, IA3Config, LongLoRAModelType, LoraConfig,
LoRAConfig, NEFTuneConfig, Swift, VeraConfig)
from swift.tuners.llamapro import LLaMAProConfig
from swift.tuners.module_mapping import MODEL_KEYS_MAPPING
from swift.utils import activate_model_parameters, freeze_model_parameters, get_logger, use_torchacc
@@ -24,6 +24,10 @@ def handle_target_modules(model, args: SftArguments) -> None:
target_modules = args.ia3_target_modules
assert len(args.ia3_feedforward_modules) > 0, ('Setting ia3_target_modules to `ALL` '
'need to pass MLP linear names to `ia3_feedforward_modules`')
elif args.sft_type == 'vera':
target_modules = args.vera_target_modules
elif args.sft_type == 'boft':
target_modules = args.boft_target_modules
else:
target_modules = args.lora_target_modules
if args.lora_use_embedding:
@@ -33,14 +37,43 @@ def handle_target_modules(model, args: SftArguments) -> None:
if args.sft_type == 'ia3':
args.ia3_target_modules = target_modules
logger.info(f'ia3_target_modules: {args.ia3_target_modules}')
elif args.sft_type == 'vera':
args.vera_target_modules = target_modules
logger.info(f'vera_target_modules: {args.vera_target_modules}')
elif args.sft_type == 'boft':
args.boft_target_modules = target_modules
logger.info(f'boft_target_modules: {args.boft_target_modules}')
else:
args.lora_target_modules = target_modules
logger.info(f'lora_target_modules: {args.lora_target_modules}')


def handle_same_dim_target_modules(model: torch.nn.Module, config: VeraConfig):
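# Vera requires all target Linear layers to share the same weight shape; if the resolved
# target modules have mixed shapes, keep only those matching the shape of the 'v' projection.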
target_modules = config.target_modules
modules_dict = {
name: module.weight.shape
for name, module in model.named_modules()
if isinstance(module, torch.nn.Linear) and any([t in name for t in target_modules])
} # only Linear for now
if len(set(modules_dict.values())) > 1:
v = [t for t in target_modules if 'v' in t]
if not v:
raise ValueError('Please manually pass in `vera_target_modules`, do not use `DEFAULT` or `ALL`, '
'because Vera needs all target linear layers to have the same shape.')
v = v[0]
shape = [shape for name, shape in modules_dict.items() if v in name][0]
names = [_name for _name, _shape in modules_dict.items() if _shape == shape]
config.target_modules = [t for t in target_modules if any([t in name for name in names])]
return config


def handle_modules_to_save(model, args: SftArguments) -> None:
if args.sft_type == 'ia3':
modules_to_save = args.ia3_modules_to_save
elif args.sft_type == 'vera':
modules_to_save = args.vera_modules_to_save
elif args.sft_type == 'boft':
modules_to_save = args.boft_modules_to_save
else:
modules_to_save = args.lora_modules_to_save
if args.lora_m2s_use_embedding:
@@ -51,6 +84,12 @@ def handle_modules_to_save(model, args: SftArguments) -> None:
if args.sft_type == 'ia3':
args.ia3_modules_to_save = modules_to_save
logger.info(f'ia3_modules_to_save: {args.ia3_modules_to_save}')
elif args.sft_type == 'vera':
args.vera_modules_to_save = modules_to_save
logger.info(f'vera_modules_to_save: {args.vera_modules_to_save}')
elif args.sft_type == 'boft':
args.boft_modules_to_save = modules_to_save
logger.info(f'boft_modules_to_save: {args.boft_modules_to_save}')
else:
args.lora_modules_to_save = modules_to_save
logger.info(f'lora_modules_to_save: {args.lora_modules_to_save}')
@@ -62,6 +101,8 @@ def prepare_model(model, args: SftArguments):
if args.resume_from_checkpoint is None:
handle_target_modules(model, args)
handle_modules_to_save(model, args)
if args.init_lora_weights and args.init_lora_weights.lower() in ('true', 'false'):
args.init_lora_weights = args.init_lora_weights.lower() == 'true'
lora_kwargs = {
'r': args.lora_rank,
'target_modules': args.lora_target_modules,
@@ -72,6 +113,7 @@ def prepare_model(model, args: SftArguments):
'use_rslora': args.use_rslora,
'use_dora': args.use_dora,
'lorap_lr_ratio': args.lora_lr_ratio,
'init_lora_weights': args.init_lora_weights,
}
if args.sft_type in ('lora', 'longlora'):
if args.lora_dtype == 'AUTO':
@@ -158,6 +200,29 @@ def prepare_model(model, args: SftArguments):
act_layer=args.adapter_act)
model = Swift.prepare_model(model, adapter_config)
logger.info(f'adapter_config: {adapter_config}')
elif args.sft_type == 'vera':
vera_config = VeraConfig(
r=args.vera_rank,
target_modules=args.vera_target_modules,
projection_prng_key=args.vera_projection_prng_key,
vera_dropout=args.vera_dropout,
d_initial=args.vera_d_initial,
modules_to_save=args.vera_modules_to_save,
)
vera_config = handle_same_dim_target_modules(model, vera_config)
model = Swift.prepare_model(model, vera_config)
logger.info(f'vera_config: {vera_config}')
elif args.sft_type == 'boft':
boft_config = BOFTConfig(
boft_block_size=args.boft_block_size,
boft_block_num=args.boft_block_num,
boft_n_butterfly_factor=args.boft_n_butterfly_factor,
target_modules=args.boft_target_modules,
boft_dropout=args.boft_dropout,
modules_to_save=args.boft_modules_to_save,
)
model = Swift.prepare_model(model, boft_config)
logger.info(f'boft_config: {boft_config}')
else:
if use_torchacc():
consolidate_checkpoint(args.resume_from_checkpoint, 'adapter_model')
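In `prepare_model` above, the new `init_lora_weights` value is simply forwarded into `lora_kwargs`. For reference, a hedged sketch of the equivalent plain-peft `LoraConfig` (the target-module names are placeholders):

```python
# Hedged sketch: what `--sft_type lora --init_lora_weights pissa` amounts to in peft >= 0.11.
from peft import LoraConfig

pissa_config = LoraConfig(
    r=8, lora_alpha=32,
    target_modules=['q_proj', 'k_proj', 'v_proj'],  # placeholder module names
    init_lora_weights='pissa',                      # or e.g. 'pissa_niter_16' for the fast-SVD variant
)
```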
29 changes: 26 additions & 3 deletions swift/llm/utils/argument.py
@@ -32,7 +32,7 @@


def is_adapter(sft_type: str) -> bool:
return sft_type in {'lora', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter'}
return sft_type in {'lora', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft'}


class ArgumentsBase:
@@ -404,7 +404,7 @@ class SftArguments(ArgumentsBase):
default=None,
metadata={'help': "Decoder Class name of model, e.g. 'QWenBlock' for QWen, 'LlamaDecoderLayer' for LLama"})

sft_type: Literal['lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter'] = 'lora'
sft_type: Literal['lora', 'full', 'longlora', 'adalora', 'ia3', 'llamapro', 'adapter', 'vera', 'boft'] = 'lora'
freeze_parameters: float = 0. # 0 ~ 1
additional_trainable_parameters: List[str] = field(default_factory=list)
tuner_backend: Literal['swift', 'peft', 'unsloth'] = 'peft'
@@ -457,6 +457,23 @@ class SftArguments(ArgumentsBase):
lora_lr_ratio: float = None
use_rslora: bool = False
use_dora: bool = False
init_lora_weights: Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'loftq', 'true', 'false'] = 'true'

# BOFT
boft_block_size: int = 4
boft_block_num: int = 0
boft_n_butterfly_factor: int = 1
boft_target_modules: Optional[Union[List[str], str]] = field(default_factory=lambda: ['DEFAULT'])
boft_dropout: float = 0.0
boft_modules_to_save: List[str] = field(default_factory=list)

# Vera
vera_rank: int = 256
vera_target_modules: Optional[Union[List[str], str]] = field(default_factory=lambda: ['DEFAULT'])
vera_projection_prng_key: int = 0
vera_dropout: float = 0.0
vera_d_initial: float = 0.1
vera_modules_to_save: List[str] = field(default_factory=list)

# adapter
adapter_act: str = 'gelu'
@@ -684,6 +701,12 @@ def __post_init__(self) -> None:
self.ia3_feedforward_modules = self._prepare_target_modules(self.ia3_feedforward_modules)
self.ia3_target_modules = self._prepare_target_modules(self.ia3_target_modules)
self.ia3_modules_to_save = self._prepare_modules_to_save(self.ia3_modules_to_save)
elif self.sft_type == 'vera':
self.vera_target_modules = self._prepare_target_modules(self.vera_target_modules)
self.vera_modules_to_save = self._prepare_modules_to_save(self.vera_modules_to_save)
elif self.sft_type == 'boft':
self.boft_target_modules = self._prepare_target_modules(self.boft_target_modules)
self.boft_modules_to_save = self._prepare_modules_to_save(self.boft_modules_to_save)
else:
self.lora_target_modules = self._prepare_target_modules(self.lora_target_modules)
self.lora_modules_to_save = self._prepare_modules_to_save(self.lora_modules_to_save)
@@ -926,7 +949,7 @@ class InferArguments(ArgumentsBase):
model_id_or_path: Optional[str] = None
model_revision: Optional[str] = None

sft_type: Literal['lora', 'longlora', 'full', 'adalora', 'ia3', 'llamapro'] = 'lora'
sft_type: Literal['lora', 'longlora', 'full', 'adalora', 'ia3', 'llamapro', 'vera', 'boft'] = 'lora'
template_type: str = field(
default='AUTO', metadata={'help': f"template_type choices: {list(TEMPLATE_MAPPING.keys()) + ['AUTO']}"})
infer_backend: Literal['AUTO', 'vllm', 'pt'] = 'AUTO'
5 changes: 4 additions & 1 deletion swift/trainers/mixin.py
@@ -11,6 +11,7 @@

import json
import numpy as np
import peft
import safetensors
import torch
import transformers
@@ -250,6 +251,8 @@ def __init__(self,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
**kwargs)
if not self.label_names:
self.label_names = ['labels']
if is_quantized and use_swift:
model._hf_peft_config_loaded = _hf_peft_config_loaded

@@ -381,7 +384,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
from swift import SWIFT_MAPPING
addtional_module_tuners = [
name.lower() for name, (config, cls) in SWIFT_MAPPING.items() if cls.has_additional_modules()
]
] + list(peft.PEFT_TYPE_TO_CONFIG_MAPPING.keys())
if self.tokenizer is not None and sft_args.sft_type not in addtional_module_tuners:
self.tokenizer.save_pretrained(output_dir)
# training_args.bin
(3 more changed files not shown.)
