Support fine-tuning of ChatGLM3, including its function call capability
yangjianxin1 committed Nov 20, 2023
1 parent 8d4eaf8 commit e2c0b34
Showing 8 changed files with 328 additions and 15 deletions.
127 changes: 127 additions & 0 deletions ChatGLM3.md
@@ -0,0 +1,127 @@
# Introduction to ChatGLM3 Fine-tuning
ChatGLM3 fine-tuning gets its own document because the model natively supports function call, and we want to preserve that capability during fine-tuning. We therefore adapted the pipeline specifically for ChatGLM3, mainly in data preprocessing.

## Data Processing
To preserve ChatGLM3's original chat and function call abilities, training uses the same data concatenation format as the official implementation.
For the detailed data processing logic for ChatGLM3, see: [data processing logic](https://github.com/yangjianxin1/Firefly/blob/master/component/dataset.py#L107)

## Training Data Format
For fine-tuning we use the same data file format as ChatGLM3. An example is shown below; see also data/dummy_data_chatglm3.jsonl and the [official documentation](https://github.com/THUDM/ChatGLM3/tree/main/finetune_demo).

Function-call and non-function-call training data can be mixed during fine-tuning. If a sample requires a function call, it must have a `tools` field, and its conversations must include entries whose `role` is `tool`.
If a sample is not function-call data, it must not contain a `tools` field, and its conversations must not include entries whose `role` is `tool`. An illustrative non-function-call sample is shown after the function-call example below.
```json
{
    "tools":[
        {
            "name":"get_current_weather",
            "description":"Get the current weather in a given location",
            "parameters":{
                "type":"object",
                "properties":{
                    "location":{
                        "type":"string",
                        "description":"The city and state, e.g. San Francisco, CA"
                    },
                    "unit":{
                        "type":"string"
                    }
                },
                "required":[
                    "location"
                ]
            }
        }
    ],
    "conversations":[
        {
            "role":"user",
            "content":"What's the weather like in Beijing today?"
        },
        {
            "role":"tool",
            "name":"get_current_weather",
            "parameters":{
                "location":"beijing"
            },
            "observation":{
                "temperature":"20°C",
                "wind force":"level 4"
            }
        },
        {
            "role":"assistant",
            "content":"The temperature in Beijing today is 20°C, with wind force level 4."
        },
        {
            "role":"user",
            "content":"What tourist attractions are there in Beijing?"
        },
        {
            "role":"assistant",
            "content":"Tiananmen Square, the Palace Museum, the Temple of Heaven, and the Great Wall are all well worth visiting."
        }
    ]
}
```
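
For contrast, a non-function-call sample omits the `tools` field and contains only `user` and `assistant` turns. The following record is purely illustrative:
```json
{
    "conversations":[
        {
            "role":"user",
            "content":"Hello, who are you?"
        },
        {
            "role":"assistant",
            "content":"Hello, I am an AI assistant. How can I help you today?"
        }
    ]
}
```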

## Data Format Conversion
To stay compatible with ChatGLM3's function-call fine-tuning, we adopt its official data format, which differs substantially from firefly's. So when training ChatGLM3, you need to manually convert firefly-format training data into this format.

We provide a simple [data conversion script](https://github.com/yangjianxin1/Firefly/blob/master/script/convert_data_format.py) that converts the datasets previously open-sourced by firefly directly into the ChatGLM3 training format.
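
For reference, the core of such a conversion might look like the sketch below. It assumes firefly's published schema of `{"conversation": [{"human": ..., "assistant": ...}]}` and uses hypothetical input/output file names; the linked script remains the authoritative implementation.
```python
import json

def firefly_to_chatglm3(line: str) -> str:
    """Convert one firefly-format JSON line into the ChatGLM3 format."""
    sample = json.loads(line)
    conversations = []
    for turn in sample['conversation']:
        conversations.append({'role': 'user', 'content': turn['human']})
        conversations.append({'role': 'assistant', 'content': turn['assistant']})
    # Plain chat data: no "tools" field and no role == "tool" entries.
    return json.dumps({'conversations': conversations}, ensure_ascii=False)

# Hypothetical file names, for illustration only.
with open('firefly_train.jsonl', encoding='utf-8') as fin, \
        open('chatglm3_train.jsonl', 'w', encoding='utf-8') as fout:
    for line in fin:
        fout.write(firefly_to_chatglm3(line) + '\n')
```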

## Fine-tuning ChatGLM3
The training configuration parameters are all stored in [chatglm3-6b-sft-qlora.json](https://github.com/yangjianxin1/Firefly/blob/master/train_args/qlora/chatglm3-6b-sft-qlora.json).

For single-GPU training, simply run:
```bash
python train_qlora.py --train_args_file train_args/qlora/chatglm3-6b-sft-qlora.json
```

For multi-GPU training, run:
```bash
torchrun --nproc_per_node={num_gpus} train_qlora.py --train_args_file train_args/qlora/chatglm3-6b-sft-qlora.json
```

**Note: the value of `model_name_or_path` in chatglm3-6b-sft-qlora.json must contain `chatglm3`, otherwise the data processing logic will break.** This is because we dispatch per-model data processing based on `model_name_or_path`, as follows:
```python
# load the training set in the ChatGLM2 format
if 'chatglm2' in args.model_name_or_path:
    train_dataset = ChatGLM2SFTDataset(args.train_file, tokenizer, args.max_seq_length)
# load the training set in the ChatGLM3 format
elif 'chatglm3' in args.model_name_or_path:
    train_dataset = ChatGLM3SFTDataset(args.train_file, tokenizer, args.max_seq_length)
# otherwise, concatenate in the firefly format
else:
    train_dataset = SFTDataset(args.train_file, tokenizer, args.max_seq_length)
```

## Inference
ChatGLM3's official inference script can be used directly:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

model_name_or_path = 'THUDM/chatglm3-6b'
adapter_name_or_path = 'path-to-adapter'

# load the base model
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto',
)
# load the adapter
if adapter_name_or_path is not None:
    model = PeftModel.from_pretrained(model, adapter_name_or_path)
# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
model = model.eval()
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)
response, history = model.chat(tokenizer, "What should I do if I can't sleep at night?", history=history)
print(response)
```
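
To exercise the fine-tuned function call capability, the tool definitions can be supplied in a system turn, following the pattern of ChatGLM3's official tool-usage demo. The sketch below reuses `model` and `tokenizer` from the script above and is an assumption-laden illustration: verify the `model.chat(..., role="observation")` signature and the dict-shaped tool response against your model revision.
```python
import json

# tool definitions in the same schema as the "tools" field of the training data
tools = [{
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
            "unit": {"type": "string"}
        },
        "required": ["location"]
    }
}]
system_info = {
    "role": "system",
    "content": "Answer the following questions as best as you can. You have access to the following tools:",
    "tools": tools
}

# first turn: the model is expected to reply with a tool call (a dict) rather than plain text
response, history = model.chat(tokenizer, "What's the weather like in Beijing today?", history=[system_info])
print(response)  # e.g. {'name': 'get_current_weather', 'parameters': {'location': 'beijing'}}

# execute the tool yourself, then feed the result back as an observation turn
result = {"temperature": "20°C", "wind force": "level 4"}
response, history = model.chat(tokenizer, json.dumps(result, ensure_ascii=False),
                               history=history, role="observation")
print(response)  # a natural-language answer grounded in the observation
```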
9 changes: 8 additions & 1 deletion README.md
@@ -20,6 +20,8 @@


## News
- 🔥 Support instruction fine-tuning of ChatGLM3 in a format consistent with the original model, including fine-tuning of its function call capability. See the [ChatGLM3 fine-tuning guide](https://github.com/yangjianxin1/Firefly/blob/master/ChatGLM3.md) for details.
- 🔥 Open-sourced [LongQLoRA](https://github.com/yangjianxin1/LongQLoRA), with a [technical report](https://arxiv.org/abs/2311.04879). It efficiently extends LLaMA's context length, stretching Llama2 to 8k (extendable to 12k) on a single 32GB V100 with only 1000 fine-tuning steps; its perplexity on the PG19 and Proof-pile datasets beats LongLoRA, and it slightly outperforms MPT-7B-8K on PG19.
- 🔥 Support instruction fine-tuning of the Wudao Aquila2-34B model.
- 🔥 Open-sourced the [Firefly-LLaMA2-Chinese project](https://github.com/yangjianxin1/Firefly-LLaMA2-Chinese). **Trained on 4*V100** with Chinese vocabulary expansion, incremental pretraining, and multi-turn instruction fine-tuning, it surpasses Linly, Yayi, and FlagAlpha on CMMLU and performs roughly on par with Ziya and Chinese-Alpaca. The project also supports efficient incremental pretraining of Baichuan, Qwen, InternLM, LLaMA, Falcon, and other models.
- 🔥 Open-sourced [firefly-baichuan2-13b](https://huggingface.co/YeungNLP/firefly-baichuan2-13b), which scores 56.83 on OpenCompass's CMMLU leaderboard, ranking 8th, about 1.57 points below Baichuan's official Chat model.
@@ -306,7 +308,12 @@ The QLoRA paper notes that this method can fine-tune a 33B model on a single V100

We apply QLoRA to bloom-7b1: the adapter has roughly 120 million parameters, more than a bert-base model, and training with a sequence length of 1024 fits on a V100.

💻 Run the following command to start QLoRA fine-tuning:
💻 For a single GPU, it is recommended to launch the script with python:
```bash
python train_qlora.py --train_args_file train_args/qlora/baichuan-7b-sft-qlora.json
```

💻 For multiple GPUs, launch the script with torchrun:
```bash
torchrun --nproc_per_node={num_gpus} train_qlora.py --train_args_file train_args/qlora/baichuan-7b-sft-qlora.json
```
82 changes: 82 additions & 0 deletions component/dataset.py
@@ -1,5 +1,8 @@
import json
from loguru import logger
import ast
import astunparse
from typing import Dict
from torch.utils.data import Dataset


@@ -100,3 +103,82 @@ def __getitem__(self, index):
        }
        return inputs


class ChatGLM3SFTDataset(SFTDataset):
    def __init__(self, file, tokenizer, max_seq_length):
        super(ChatGLM3SFTDataset, self).__init__(file, tokenizer, max_seq_length)
        self.FUNCTION_CALL_NAME = 'tool_call'
        self.FUNCTION_CALL_PREFIX = '```python\n'
        self.FUNCTION_CALL_POSTFIX = '\n```'
        self.TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You have access to the following tools:\n'

    def format_function_call(self, function_name: str, parameters: Dict[str, str]):
        function_name = ast.Name(id=function_name)
        keywords = [
            ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
            for arg_name, arg_value in parameters.items()
        ]
        func_call = ast.Call(func=function_name, args=[], keywords=keywords)
        return astunparse.unparse(func_call).strip()
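    # Illustrative example (assumed output of astunparse on this AST):
    #   format_function_call('tool_call', {'location': 'beijing'})
    #   -> "tool_call(location='beijing')"
    # The result is wrapped in FUNCTION_CALL_PREFIX/POSTFIX to form a fenced
    # ```python block inside the assistant turn.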

    def __getitem__(self, index):
        """
        Follow ChatGLM3's instruction fine-tuning format, with support for function-call fine-tuning.
        """
        data = self.data_list[index]
        data = json.loads(data)
        conversations = data['conversations']

        gmask_token_id = self.tokenizer.get_command('[gMASK]')
        sop_token_id = self.tokenizer.get_command('sop')
        input_ids = [gmask_token_id, sop_token_id]  # collect token ids, starting from the [gMASK] + sop prefix
        target_mask = [0] * 2

        # this sample contains function calls: prepend the tool definitions as a system turn
        if 'tools' in data.keys():
            conversations.insert(
                0, {"role": "system", "content": self.TOOL_DEFINITION_PREFIX + json.dumps(data['tools'], indent=4, ensure_ascii=False)}
            )

        # concatenate the multi-turn conversation
        for i, conv in enumerate(conversations):
            role = conv['role'].strip()
            if role == 'tool':
                # function call
                value = self.FUNCTION_CALL_PREFIX + self.format_function_call(self.FUNCTION_CALL_NAME, conv["parameters"]) + self.FUNCTION_CALL_POSTFIX
                token_ids = self.tokenizer.build_single_message("assistant", conv["name"], value) + [self.tokenizer.eos_token_id]
                input_ids += token_ids
                # no loss on the <|assistant|> token
                target_mask += [0] + [1] * (len(token_ids) - 1)

                # function call result
                value = conv.get('observation', None)
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)
                token_ids = self.tokenizer.build_single_message("observation", "", value) + [self.tokenizer.eos_token_id]
                input_ids += token_ids
                target_mask += [0] * len(token_ids)
            else:
                token_ids = self.tokenizer.build_single_message(role, "", conv["content"]) + [self.tokenizer.eos_token_id]
                input_ids += token_ids
                if role == 'system' or role == 'user':
                    target_mask += [0] * len(token_ids)
                # role == assistant
                else:
                    # no loss on the <|assistant|> token
                    target_mask += [0] + [1] * (len(token_ids) - 1)

        assert len(input_ids) == len(target_mask)
        # truncate to max_seq_length
        input_ids = input_ids[:self.max_seq_length]
        target_mask = target_mask[:self.max_seq_length]
        attention_mask = [1] * len(input_ids)
        assert len(input_ids) == len(target_mask) == len(attention_mask)
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target_mask': target_mask
        }
        return inputs
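
# Hypothetical usage sketch (illustrative; assumes the dummy data file shipped with the repo):
#   tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm3-6b', trust_remote_code=True)
#   dataset = ChatGLM3SFTDataset('data/dummy_data_chatglm3.jsonl', tokenizer, max_seq_length=1024)
#   item = dataset[0]
#   # input_ids, attention_mask and target_mask have equal length; target_mask
#   # marks the assistant and tool-call tokens that contribute to the loss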


