
## 使用 Usage

Due to the license restrictions on the LLaMA weights, this model must not be used for commercial purposes; please strictly comply with LLaMA's usage policy. Because of these license restrictions, we cannot release the full model weights directly.

Step 1: Obtain the original [LLaMA](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) weights and convert them to the Hugging Face Transformers format (hf format for short). You can refer to the conversion script [convert_llama_weights_to_hf.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py); copy the code and save it locally.

From the directory containing convert_llama_weights_to_hf.py, run:
```bash
python convert_llama_weights_to_hf.py --input_dir {path to the original llama-13b weights} --model_size 13B --output_dir {output path}
```

For example:

```bash
python convert_llama_weights_to_hf.py --input_dir /home/llama-13b --model_size 13B --output_dir /home/llama-13b-hf
```


Step 2: Download the delta weights of [Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1), then merge them with the hf-format LLaMA weights using the script below to obtain the full Ziya-LLaMA-13B-v1 model weights. Merge script: [https://github.com/IDEA-CCNL/Fengshenbang-LM/blob/main/fengshen/utils/apply_delta.py](https://github.com/IDEA-CCNL/Fengshenbang-LM/blob/main/fengshen/utils/apply_delta.py)

As in step 1, copy the code from apply_delta.py and save it locally.

From the directory containing apply_delta.py, run:

```bash
python apply_delta.py --base {path to the hf-format LLaMA weights} --target {output path} --delta {path to the Ziya-LLaMA-13B-v1 delta weights}
```

For example:

```bash
python3 apply_delta.py --base /home/llama-13b-hf --target /home/Ziya-LLaMA-13B --delta /home/Ziya-LLaMA-13B-v1
```
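
If the merge succeeded, the target directory is a self-contained model that Transformers can load directly. A minimal smoke test, assuming the example paths above (requires a GPU with enough memory for the 13B model in fp16):

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/home/Ziya-LLaMA-13B")
model = LlamaForCausalLM.from_pretrained(
    "/home/Ziya-LLaMA-13B",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Generate a few tokens to confirm the merged weights are coherent.
inputs = tokenizer("你好", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```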

Step 3: Merge the ChatLaw weights and run inference

```python
import re
import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

def main():
    ziya_model_path = "/home/Ziya-LLaMA-13B"  # path to the full Ziya-LLaMA-13B-v1 weights from step 2
    chatlaw_model_path = "/home/chatlaw"  # path to the ChatLaw (LoRA) weights
    tokenizer = LlamaTokenizer.from_pretrained(ziya_model_path)
    model = LlamaForCausalLM.from_pretrained(
        ziya_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # Apply the ChatLaw adapter on top of the Ziya base model
    model = PeftModel.from_pretrained(model, chatlaw_model_path)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    model.eval()

    consult = "介绍一下内控行业"  # example query: "Give an overview of the internal-control industry"
    prompt = f"Consult:\n{consult}\nResponse:\n"

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs['input_ids'] = inputs['input_ids'].to(model.device)

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
    )

    with torch.no_grad():
        generation_output = model.generate(
            **inputs,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=2048,
            repetition_penalty=1.2,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    # Extract the text between "Response:" and the end-of-sequence token
    if search_result := re.search(r"Response\s*:\s*([\s\S]+?)</s>", output):
        output = search_result.group(1)
    print(output)

if __name__ == "__main__":
    main()
```
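
If you plan to run inference repeatedly, one option is to bake the LoRA adapter into the base weights once and save a standalone checkpoint, so later runs can skip the peft step. A hedged sketch using peft's `merge_and_unload` (available in recent peft releases; `/home/chatlaw-merged` is a hypothetical output path):

```python
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base = LlamaForCausalLM.from_pretrained(
    "/home/Ziya-LLaMA-13B", torch_dtype=torch.float16, device_map="auto"
)
# merge_and_unload folds the adapter deltas into the base weights and
# returns a plain LlamaForCausalLM.
merged = PeftModel.from_pretrained(base, "/home/chatlaw").merge_and_unload()
merged.save_pretrained("/home/chatlaw-merged")
LlamaTokenizer.from_pretrained("/home/Ziya-LLaMA-13B").save_pretrained("/home/chatlaw-merged")
```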




