-
Notifications
You must be signed in to change notification settings - Fork 227
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
41 changed files
with
1,343 additions
and
266 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
# 人类偏好对齐训练文档 | ||
|
||
本文档提供了各种人类偏好对齐算法的训练脚本。若您希望深入了解更详尽的算法信息及其选择方法,请参考[文档](https://github.com/modelscope/modelscope-classroom/blob/main/LLM-tutorial/M.%E4%BA%BA%E7%B1%BB%E5%81%8F%E5%A5%BD%E5%AF%B9%E9%BD%90%E8%AE%AD%E7%BB%83.md) | ||
|
||
## 目录 | ||
- [环境准备](#环境准备) | ||
- [数据集](#数据集) | ||
- [DPO](#dpo) | ||
- [KTO](#kto) | ||
- [CPO](#cpo) | ||
- [ORPO](#orpo) | ||
- [SimPO](#simpo) | ||
|
||
## 环境准备 | ||
```bash | ||
# 设置pip全局镜像 (加速下载) | ||
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ | ||
# 安装ms-swift | ||
git clone https://github.com/modelscope/swift.git | ||
cd swift | ||
pip install -e '.[llm]' | ||
|
||
# 环境对齐 (通常不需要运行. 如果你运行错误, 可以跑下面的代码, 仓库使用最新环境测试) | ||
pip install -r requirements/framework.txt -U | ||
pip install -r requirements/llm.txt -U | ||
``` | ||
|
||
|
||
## 数据集 | ||
|
||
人类偏好对齐训练一般需要 $(x,y_w,y_l)$ 格式的数据,其中 $x$ 表示模型输入,$y_w,y_l$ 分别表示符合人类偏好的偏好回答和不符合人类偏好的拒绝回答,比如![dpo_data](../../resources/dpo_data.png) | ||
|
||
其中KTO算法的数据比较特殊,只需要$(x,y,\text{label})$格式的数据,其中$x$表示模型输入,$y$表示模型输出,label表示回答是否符合人类偏好 | ||
比如![kto_data](../../resources/kto_data.png) | ||
|
||
KTO也可以使用第一种数据格式进行训练,训练脚本上的差异见KTO章节。 | ||
|
||
**训练提示**: | ||
- 如果用带有history的数据训练base模型,需要指定支持多轮对话的template(base模型往往不支持多轮对话),对于这种情况我们默认设置了`chatml`template,你也可以使用`--model_type` 来选择训练模型的template | ||
- 使用自定义数据集进行训练请参考[自定义与拓展](自定义与拓展.md) | ||
- 下面的训练脚本使用`--lora_target_modules ALL`来训练模型的全部线性层,你也可以设置`--lora_target_modules DEFAULT`只训练模型的QKV矩阵 | ||
|
||
## DPO | ||
[论文arvix](https://arxiv.org/abs/2305.18290) | ||
|
||
超参 | ||
- `beta`:KL正则系数,值越大表示对偏离参考模型的惩罚越大。默认为0.1 | ||
|
||
建议在开始DPO训练之前,使用偏好数据集中的偏好回答部分进行SFT训练,以确保数据符合DPO算法的分布要求。 | ||
我们也在DPO loss中混合了sft loss来稳定训练,你可以通过设置超参`sft_beta`来调整sft loss的系数,默认为0.1 | ||
|
||
训练脚本, 这里我们提供单卡/多卡device map/多卡ddp的版本,简洁起见,后续算法只给出单卡版本。 | ||
```bash | ||
# Experimental environment: A100 | ||
# Memory usage: 40G | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift rlhf \ | ||
--rlhf_type dpo \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--sft_beta 0.1 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
|
||
# MP(device map) | ||
# Memory usage: 2*24G | ||
CUDA_VISIBLE_DEVICES=0,1 \ | ||
swift rlhf \ | ||
--rlhf_type dpo \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--sft_beta 0.1 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
|
||
# DDP + MP | ||
# Memory usage: 4*24G | ||
CUDA_VISIBLE_DEVICES=0,1,2,3 \ | ||
NPROC_PER_NODE=2 \ | ||
swift rlhf \ | ||
--rlhf_type dpo \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--sft_beta 0.1 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps $(expr 16 / $nproc_per_node) \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
``` | ||
|
||
训练后的模型推理和部署可以参考[LLM推理文档](./LLM推理文档.md)和[VLLM推理加速与部署文档](./VLLM推理加速与部署.md) | ||
|
||
## KTO | ||
[论文arvix](https://arxiv.org/abs/2402.01306) | ||
|
||
超参 | ||
- beta: KL正则系数,值越大表示对偏离参考模型的惩罚越大。默认为0.1 | ||
- desirable_weight :损失函数中的$\lambda_D$项,偏好回答样本的损失权重, 默认为1.0 | ||
- undesirable_weight :损失函数中的$\lambda_U$项,拒绝回答样本的损失权重,默认为1.0 | ||
|
||
用 $n_D$ 和 $n_U$ 分别表示数据集中偏好回答和拒绝回答的样本数量,对于超参 $\lambda_D$ 和 $\lambda_U$ ,作者推荐设置 $\frac{\lambda_Dn_D}{\lambda_Un_U}\in[1,\frac{4}{3}]$ | ||
|
||
训练脚本 | ||
使用 $(x,y,\text{label})$ 格式数据训练 | ||
|
||
```bash | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift rlhf \ | ||
--rlhf_type kto \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--desirable_weight 1.0 \ | ||
--undesirable_weight 1.0 \ | ||
--sft_type lora \ | ||
--dataset ultrafeedback-kto \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
``` | ||
|
||
使用$(x,y_w,y_l)$格式数据训练 | ||
```bash | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift rlhf \ | ||
--rlhf_type dpo \ | ||
--loss_type kto_pair \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--desirable_weight 1.0 \ | ||
--undesirable_weight 1.0 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
``` | ||
|
||
## CPO | ||
[论文arvix](https://arxiv.org/abs/2401.08417) | ||
超参 | ||
- beta:隐含奖励前的系数,默认为0.1 | ||
|
||
训练脚本 | ||
```bash | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift rlhf \ | ||
--rlhf_type cpo \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
``` | ||
|
||
## ORPO | ||
[论文arvix](https://arxiv.org/abs/2403.07691) | ||
|
||
超参 | ||
- lambda: Odds Ratio loss系数 | ||
|
||
注意:ORPO使用参数`--beta`传入超参`lambda` | ||
```bash | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift rlhf \ | ||
--rlhf_type orpo \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 0.1 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
``` | ||
|
||
## SimPO | ||
[论文arvix](https://arxiv.org/abs/2405.14734) | ||
超参 | ||
- beta:隐含奖励前的系数,默认为2.0 | ||
- simpo_gamma:reward margin项,默认为1.0 | ||
|
||
```bash | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift rlhf \ | ||
--rlhf_type simpo \ | ||
--model_type llama3-8b-instruct \ | ||
--beta 2.0 \ | ||
--simpo_gamma 1.0 \ | ||
--sft_type lora \ | ||
--dataset shareai-llama3-dpo-zh-en-emoji \ | ||
--num_train_epochs 2 \ | ||
--lora_target_modules ALL \ | ||
--gradient_checkpointing true \ | ||
--batch_size 1 \ | ||
--learning_rate 5e-5 \ | ||
--gradient_accumulation_steps 16 \ | ||
--warmup_ratio 0.03 \ | ||
--save_total_limit 2 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.