support gemma #441

Merged: 6 commits, Feb 22, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -62,6 +62,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用.md)


## 🎉 News
- 🔥2024.02.22: Support gemma series: gemma-2b, [gemma-2b-instruct](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/gemma_2b_instruct), gemma-7b, gemma-7b-instruct.
- 2024.02.16: Support deepseek-math series: deepseek-math-7b, deepseek-math-7b-instruct, deepseek-math-7b-chat.
- 🔥2024.02.05: Support the **Qwen1.5** series. To view all supported Qwen1.5 models, please check the [Model List](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%A8%A1%E5%9E%8B). Fine-tuning scripts are provided for [qwen1half-7b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen1half_7b_chat) and [qwen1half-7b-chat-int8](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen1half_7b_chat_int8).
- 2024.02.05: Support training of **SDXL**, **SD**, and **ControlNet**, as well as techniques like **DreamBooth**; see the [training scripts](https://github.com/modelscope/swift/tree/main/examples/pytorch/sdxl/scripts) for details.
@@ -229,14 +230,15 @@ app_ui_main(infer_args)
- internlm-7b, internlm-7b-chat, internlm-7b-chat-8k, internlm-20b, internlm-20b-chat.
- internlm2-7b-base, internlm2-7b, internlm2-7b-sft-chat, internlm2-7b-chat, internlm2-20b-base, internlm2-20b, internlm2-20b-sft-chat, internlm2-20b-chat.
- [deepseek](https://github.com/deepseek-ai/deepseek-LLM) series: deepseek-7b, deepseek-7b-chat, deepseek-67b, deepseek-67b-chat, deepseek-moe-16b, deepseek-moe-16b-chat.
- [gemma](https://github.com/google/gemma_pytorch) series: gemma-2b, gemma-2b-instruct, gemma-7b, gemma-7b-instruct.
- [openbmb-minicpm](https://github.com/OpenBMB/mlc-MiniCPM) series: openbmb-minicpm-2b-sft-chat, openbmb-minicpm-2b-chat.
- [openbuddy](https://github.com/OpenBuddy/OpenBuddy) series: openbuddy-llama2-13b-chat, openbuddy-llama-65b-chat, openbuddy-llama2-70b-chat, openbuddy-mistral-7b-chat, openbuddy-zephyr-7b-chat, openbuddy-deepseek-67b-chat, openbuddy-mixtral-moe-7b-chat.
- [mistral](https://github.com/mistralai/mistral-src) series: mistral-7b, mistral-7b-instruct, mistral-7b-instruct-v2.
- [mixtral](https://github.com/mistralai/mistral-src) series: mixtral-moe-7b, mixtral-moe-7b-instruct.
- [baichuan](https://github.com/baichuan-inc/Baichuan2) series: baichuan-7b, baichuan-13b, baichuan-13b-chat, baichuan2-7b, baichuan2-7b-chat, baichuan2-13b, baichuan2-13b-chat, baichuan2-7b-chat-int4, baichuan2-13b-chat-int4.
- [yuan](https://github.com/IEIT-Yuan/Yuan-2.0) series: yuan2-2b-instruct, yuan2-2b-janus-instruct, yuan2-51b-instruct, yuan2-102b-instruct.
- [xverse](https://github.com/xverse-ai/XVERSE-13B) series: xverse-7b, xverse-7b-chat, xverse-13b, xverse-13b-chat, xverse-65b, xverse-65b-v2, xverse-65b-chat, xverse-13b-256k.
- [orion](https://github.com/OrionStarAI/OrionStar-Yi-34B-Chat) series: orion-14b, orion-14b-chat.
- [openbmb-minicpm](https://github.com/OpenBMB/mlc-MiniCPM) series: openbmb-minicpm-2b-sft-chat, openbmb-minicpm-2b-chat.
- [bluelm](https://github.com/vivo-ai-lab/BlueLM) series: bluelm-7b, bluelm-7b-chat, bluelm-7b-32k, bluelm-7b-chat-32k.
- [zephyr](https://github.com/huggingface/alignment-handbook) series: zephyr-7b-beta-chat.
- [ziya](https://github.com/IDEA-CCNL/Fengshenbang-LM) series: ziya2-13b, ziya2-13b-chat.
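As a quick smoke test of the new support, the gemma checkpoints can be driven from swift's Python entry points. A minimal sketch, assuming `InferArguments`/`infer_main` follow the same pattern as the `app_ui_main(infer_args)` call visible in the diff context above:

```python
# Hedged sketch: chat inference with the newly registered gemma-2b-instruct.
# InferArguments/infer_main mirroring app_ui_main(infer_args) is an assumption.
from swift.llm import InferArguments, infer_main

infer_args = InferArguments(model_type='gemma-2b-instruct')
infer_main(infer_args)
```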
4 changes: 3 additions & 1 deletion README_CN.md
@@ -60,6 +60,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is a scalable
Users can check the [official SWIFT documentation](docs/source/GetStarted/快速使用.md) for detailed information.

## 🎉 News
- 🔥2024.02.22: Support gemma series: gemma-2b, [gemma-2b-instruct](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/gemma_2b_instruct), gemma-7b, gemma-7b-instruct.
- 2024.02.16: Support deepseek-math series: deepseek-math-7b, deepseek-math-7b-instruct, deepseek-math-7b-chat.
- 🔥2024.02.05: Support the **Qwen1.5** series. To view all supported Qwen1.5 models, please check the [Model List](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%A8%A1%E5%9E%8B). Fine-tuning scripts are provided for [qwen1half-7b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen1half_7b_chat) and [qwen1half-7b-chat-int8](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen1half_7b_chat_int8).
- 2024.02.05: Support training of diffusion models such as **SDXL**, **SD**, and **ControlNet**, as well as **DreamBooth** training; see the corresponding [training scripts](https://github.com/modelscope/swift/tree/main/examples/pytorch/sdxl/scripts) for details.
@@ -228,14 +229,15 @@ app_ui_main(infer_args)
- internlm-7b, internlm-7b-chat, internlm-7b-chat-8k, internlm-20b, internlm-20b-chat.
- internlm2-7b-base, internlm2-7b, internlm2-7b-sft-chat, internlm2-7b-chat, internlm2-20b-base, internlm2-20b, internlm2-20b-sft-chat, internlm2-20b-chat.
- [deepseek](https://github.com/deepseek-ai/deepseek-LLM) series: deepseek-7b, deepseek-7b-chat, deepseek-67b, deepseek-67b-chat, deepseek-moe-16b, deepseek-moe-16b-chat.
- [gemma](https://github.com/google/gemma_pytorch) series: gemma-2b, gemma-2b-instruct, gemma-7b, gemma-7b-instruct.
- [openbmb-minicpm](https://github.com/OpenBMB/mlc-MiniCPM) series: openbmb-minicpm-2b-sft-chat, openbmb-minicpm-2b-chat.
- [openbuddy](https://github.com/OpenBuddy/OpenBuddy) series: openbuddy-llama2-13b-chat, openbuddy-llama-65b-chat, openbuddy-llama2-70b-chat, openbuddy-mistral-7b-chat, openbuddy-zephyr-7b-chat, openbuddy-deepseek-67b-chat, openbuddy-mixtral-moe-7b-chat.
- [mistral](https://github.com/mistralai/mistral-src) series: mistral-7b, mistral-7b-instruct, mistral-7b-instruct-v2.
- [mixtral](https://github.com/mistralai/mistral-src) series: mixtral-moe-7b, mixtral-moe-7b-instruct.
- [baichuan](https://github.com/baichuan-inc/Baichuan2) series: baichuan-7b, baichuan-13b, baichuan-13b-chat, baichuan2-7b, baichuan2-7b-chat, baichuan2-13b, baichuan2-13b-chat, baichuan2-7b-chat-int4, baichuan2-13b-chat-int4.
- [yuan](https://github.com/IEIT-Yuan/Yuan-2.0) series: yuan2-2b-instruct, yuan2-2b-janus-instruct, yuan2-51b-instruct, yuan2-102b-instruct.
- [xverse](https://github.com/xverse-ai/XVERSE-13B) series: xverse-7b, xverse-7b-chat, xverse-13b, xverse-13b-chat, xverse-65b, xverse-65b-v2, xverse-65b-chat, xverse-13b-256k.
- [orion](https://github.com/OrionStarAI/OrionStar-Yi-34B-Chat) series: orion-14b, orion-14b-chat.
- [openbmb-minicpm](https://github.com/OpenBMB/mlc-MiniCPM) series: openbmb-minicpm-2b-sft-chat, openbmb-minicpm-2b-chat.
- [bluelm](https://github.com/vivo-ai-lab/BlueLM) series: bluelm-7b, bluelm-7b-chat, bluelm-7b-32k, bluelm-7b-chat-32k.
- [zephyr](https://github.com/huggingface/alignment-handbook) series: zephyr-7b-beta-chat.
- [ziya](https://github.com/IDEA-CCNL/Fengshenbang-LM) series: ziya2-13b, ziya2-13b-chat.
2 changes: 1 addition & 1 deletion docs/source/LLM/Agent微调最佳实践.md
@@ -18,7 +18,7 @@ git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .[llm]

# Environment alignment (if you hit errors, you can run the code below; the repo is tested with the latest environment)
# Environment alignment (usually no need to run. If you hit errors, you can run the code below; the repo is tested with the latest environment)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
2 changes: 1 addition & 1 deletion docs/source/LLM/LLM人类对齐训练文档.md
@@ -13,7 +13,7 @@ git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .[llm]

# Environment alignment (if you hit errors, you can run the code below; the repo is tested with the latest environment)
# Environment alignment (usually no need to run. If you hit errors, you can run the code below; the repo is tested with the latest environment)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
2 changes: 1 addition & 1 deletion docs/source/LLM/LLM微调文档.md
@@ -28,7 +28,7 @@ pip install auto_gptq -U
# If you want to use bnb-based qlora training.
pip install bitsandbytes -U

# Environment alignment (if you hit errors, you can run the code below; the repo is tested with the latest environment)
# Environment alignment (usually no need to run. If you hit errors, you can run the code below; the repo is tested with the latest environment)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
2 changes: 1 addition & 1 deletion docs/source/LLM/LLM推理文档.md
@@ -21,7 +21,7 @@ pip install -e .[llm]
# auto_gptq versions are tied to specific CUDA versions; choose a version per `https://github.com/PanQiWei/AutoGPTQ#quick-installation`
pip install auto_gptq -U

# Environment alignment (if you hit errors, you can run the code below; the repo is tested with the latest environment)
# Environment alignment (usually no need to run. If you hit errors, you can run the code below; the repo is tested with the latest environment)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
2 changes: 1 addition & 1 deletion docs/source/LLM/VLLM推理加速与部署.md
@@ -21,7 +21,7 @@ pip install -e .[llm]
pip install vllm -U
pip install openai -U

# Environment alignment (if you hit errors, you can run the code below; the repo is tested with the latest environment)
# Environment alignment (usually no need to run. If you hit errors, you can run the code below; the repo is tested with the latest environment)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
8 changes: 6 additions & 2 deletions docs/source/LLM/支持的模型和数据集.md
@@ -112,6 +112,12 @@
|deepseek-math-7b|[deepseek-ai/deepseek-math-7b-base](https://modelscope.cn/models/deepseek-ai/deepseek-math-7b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✔||
|deepseek-math-7b-instruct|[deepseek-ai/deepseek-math-7b-instruct](https://modelscope.cn/models/deepseek-ai/deepseek-math-7b-instruct/summary)|q_proj, k_proj, v_proj|deepseek|✔|✔||
|deepseek-math-7b-chat|[deepseek-ai/deepseek-math-7b-rl](https://modelscope.cn/models/deepseek-ai/deepseek-math-7b-rl/summary)|q_proj, k_proj, v_proj|deepseek|✔|✔||
|gemma-2b|[AI-ModelScope/gemma-2b](https://modelscope.cn/models/AI-ModelScope/gemma-2b/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✔|transformers>=4.38|
|gemma-7b|[AI-ModelScope/gemma-7b](https://modelscope.cn/models/AI-ModelScope/gemma-7b/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✔|transformers>=4.38|
|gemma-2b-instruct|[AI-ModelScope/gemma-2b-it](https://modelscope.cn/models/AI-ModelScope/gemma-2b-it/summary)|q_proj, k_proj, v_proj|gemma|✔|✔|transformers>=4.38|
|gemma-7b-instruct|[AI-ModelScope/gemma-7b-it](https://modelscope.cn/models/AI-ModelScope/gemma-7b-it/summary)|q_proj, k_proj, v_proj|gemma|✔|✔|transformers>=4.38|
|openbmb-minicpm-2b-sft-chat|[OpenBMB/MiniCPM-2B-sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32/summary)|q_proj, k_proj, v_proj|openbmb|✔|✘||
|openbmb-minicpm-2b-chat|[OpenBMB/MiniCPM-2B-dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32/summary)|q_proj, k_proj, v_proj|openbmb|✔|✘||
|openbuddy-llama2-13b-chat|[OpenBuddy/openbuddy-llama2-13b-v8.1-fp16](https://modelscope.cn/models/OpenBuddy/openbuddy-llama2-13b-v8.1-fp16/summary)|q_proj, k_proj, v_proj|openbuddy|✔|✔||
|openbuddy-llama-65b-chat|[OpenBuddy/openbuddy-llama-65b-v8-bf16](https://modelscope.cn/models/OpenBuddy/openbuddy-llama-65b-v8-bf16/summary)|q_proj, k_proj, v_proj|openbuddy|✔|✔||
|openbuddy-llama2-70b-chat|[OpenBuddy/openbuddy-llama2-70b-v10.1-bf16](https://modelscope.cn/models/OpenBuddy/openbuddy-llama2-70b-v10.1-bf16/summary)|q_proj, k_proj, v_proj|openbuddy|✔|✔||
@@ -155,8 +161,6 @@
|ziya2-13b-chat|[Fengshenbang/Ziya2-13B-Chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)|q_proj, k_proj, v_proj|ziya|✔|✔||
|skywork-13b|[skywork/Skywork-13B-base](https://modelscope.cn/models/skywork/Skywork-13B-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✘|✘||
|skywork-13b-chat|[skywork/Skywork-13B-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat/summary)|q_proj, k_proj, v_proj|skywork|✘|✘||
|openbmb-minicpm-2b-sft-chat|[OpenBMB/MiniCPM-2B-sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32/summary)|q_proj, k_proj, v_proj|openbmb|✔|✘||
|openbmb-minicpm-2b-chat|[OpenBMB/MiniCPM-2B-dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32/summary)|q_proj, k_proj, v_proj|openbmb|✔|✘||
|zephyr-7b-beta-chat|[modelscope/zephyr-7b-beta](https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)|q_proj, k_proj, v_proj|zephyr|✔|✔|transformers>=4.34|
|polylm-13b|[damo/nlp_polylm_13b_text_generation](https://modelscope.cn/models/damo/nlp_polylm_13b_text_generation/summary)|c_attn|default-generation|✘|✘||
|seqgpt-560m|[damo/nlp_seqgpt-560m](https://modelscope.cn/models/damo/nlp_seqgpt-560m/summary)|query_key_value|default-generation|✘|✔||
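For readers adding their own rows, a small sketch of inspecting the registry entry behind the gemma rows above. `MODEL_MAPPING` lives in `swift/llm/utils/model.py` and is populated by the `@register_model` decorators; the exact dictionary key names below are assumptions:

```python
# Hedged sketch: look up the registry entry that backs the table rows above.
# Key names are assumptions, hence the defensive .get() calls.
from swift.llm.utils.model import MODEL_MAPPING

info = MODEL_MAPPING['gemma-2b-instruct']
print(info.get('model_id_or_path'))  # expected: 'AI-ModelScope/gemma-2b-it'
print(info.get('requires'))          # expected: ['transformers>=4.38']
```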
2 changes: 1 addition & 1 deletion docs/source/LLM/自我认知微调最佳实践.md
@@ -19,7 +19,7 @@ git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .[llm]

# Environment alignment (if you hit errors, you can run the code below; the repo is tested with the latest environment)
# Environment alignment (usually no need to run. If you hit errors, you can run the code below; the repo is tested with the latest environment)
pip install -r requirements/framework.txt -U
pip install -r requirements/llm.txt -U
```
11 changes: 11 additions & 0 deletions examples/pytorch/llm/scripts/gemma_2b_instruct/lora/infer.sh
@@ -0,0 +1,11 @@
# Experimental environment: V100, A10, 3090
CUDA_VISIBLE_DEVICES=0 \
swift infer \
--ckpt_dir "output/gemma-2b-instruct/vx_xxx/checkpoint-xxx" \
--load_dataset_config true \
--max_length 2048 \
--max_new_tokens 2048 \
--temperature 0.1 \
--top_p 0.7 \
--repetition_penalty 1. \
--do_sample true \
31 changes: 31 additions & 0 deletions examples/pytorch/llm/scripts/gemma_2b_instruct/lora/sft.sh
@@ -0,0 +1,31 @@
# Experimental environment: V100, A10, 3090
# 12GB GPU memory

CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model_id_or_path AI-ModelScope/gemma-2b-it \
--sft_type lora \
--tuner_backend swift \
--template_type AUTO \
--dtype AUTO \
--output_dir output \
--dataset hc3-zh \
--train_dataset_sample 5000 \
--num_train_epochs 1 \
--max_length 2048 \
--check_dataset_strategy warning \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout_p 0.05 \
--lora_target_modules ALL \
--gradient_checkpointing true \
--batch_size 1 \
--weight_decay 0.01 \
--learning_rate 1e-4 \
--gradient_accumulation_steps 16 \
--max_grad_norm 0.5 \
--warmup_ratio 0.1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 10 \
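The same run can also be launched from Python. A minimal sketch, assuming `SftArguments` field names mirror the CLI flags of `swift sft` above:

```python
# Hedged sketch: the LoRA fine-tune above driven via swift's Python API.
# SftArguments field names mirroring the CLI flags is an assumption, as is
# the list form of the CLI values `ALL` and `hc3-zh`.
from swift.llm import SftArguments, sft_main

sft_args = SftArguments(
    model_id_or_path='AI-ModelScope/gemma-2b-it',
    sft_type='lora',
    dataset=['hc3-zh'],
    train_dataset_sample=5000,
    lora_target_modules=['ALL'],
    output_dir='output')
sft_main(sft_args)
```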
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -14,5 +14,5 @@ rouge
safetensors
tensorboard
tqdm
transformers>=4.33,<4.38
transformers>=4.33,<4.39
trl>=0.7.7
47 changes: 44 additions & 3 deletions swift/llm/utils/model.py
@@ -148,6 +148,14 @@ class ModelType:
deepseek_math_7b = 'deepseek-math-7b'
deepseek_math_7b_instruct = 'deepseek-math-7b-instruct'
deepseek_math_7b_chat = 'deepseek-math-7b-chat'
# gemma
gemma_2b = 'gemma-2b'
gemma_7b = 'gemma-7b'
gemma_2b_instruct = 'gemma-2b-instruct'
gemma_7b_instruct = 'gemma-7b-instruct'
# openbmb
openbmb_minicpm_2b_sft_chat = 'openbmb-minicpm-2b-sft-chat'
openbmb_minicpm_2b_chat = 'openbmb-minicpm-2b-chat'
# openbuddy
openbuddy_llama2_13b_chat = 'openbuddy-llama2-13b-chat'
openbuddy_llama2_65b_chat = 'openbuddy-llama-65b-chat'
@@ -201,9 +209,6 @@ class ModelType:
# skywork
skywork_13b = 'skywork-13b'
skywork_13b_chat = 'skywork-13b-chat'
# openbmb
openbmb_minicpm_2b_sft_chat = 'openbmb-minicpm-2b-sft-chat'
openbmb_minicpm_2b_chat = 'openbmb-minicpm-2b-chat'
# zephyr
zephyr_7b_beta_chat = 'zephyr-7b-beta-chat'
# other
@@ -719,6 +724,42 @@ def cross_entropy_forward(self, inputs: Tensor,
return model, tokenizer


@register_model(
ModelType.gemma_2b,
'AI-ModelScope/gemma-2b',
LoRATM.llama2,
TemplateType.default_generation_bos,
requires=['transformers>=4.38'],
ignore_file_pattern=[r'.+\.gguf$'],
support_flash_attn=True,
support_vllm=True)
@register_model(
ModelType.gemma_7b,
'AI-ModelScope/gemma-7b',
LoRATM.llama2,
TemplateType.default_generation_bos,
requires=['transformers>=4.38'],
ignore_file_pattern=[r'.+\.gguf$'],
support_flash_attn=True,
support_vllm=True)
@register_model(
ModelType.gemma_2b_instruct,
'AI-ModelScope/gemma-2b-it',
LoRATM.llama2,
TemplateType.gemma,
requires=['transformers>=4.38'],
ignore_file_pattern=[r'.+\.gguf$'],
support_flash_attn=True,
support_vllm=True)
@register_model(
ModelType.gemma_7b_instruct,
'AI-ModelScope/gemma-7b-it',
LoRATM.llama2,
TemplateType.gemma,
requires=['transformers>=4.38'],
ignore_file_pattern=[r'.+\.gguf$'],
support_flash_attn=True,
support_vllm=True)
@register_model(
ModelType.deepseek_math_7b_instruct,
'deepseek-ai/deepseek-math-7b-instruct',
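To exercise these registrations end to end, a minimal sketch, assuming `get_model_tokenizer` (the loader in this same file that consumes the registry) accepts a registered model-type string:

```python
# Hedged sketch: load a registered gemma checkpoint plus its tokenizer.
# get_model_tokenizer resolving a model_type string is an assumption based
# on how the @register_model entries above are consumed.
import torch
from swift.llm.utils.model import get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    'gemma-2b-instruct', torch_dtype=torch.bfloat16)
print(type(model).__name__, type(tokenizer).__name__)
```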
8 changes: 8 additions & 0 deletions swift/llm/utils/template.py
@@ -51,6 +51,7 @@ class TemplateType:
cogagent_instruct = 'cogagent-instruct'
orion = 'orion'
openbmb = 'openbmb'
gemma = 'gemma'
# compatibility. (Deprecated)
chatml = 'chatml'

@@ -961,6 +962,13 @@ def data_collator(self,
TemplateType.openbmb,
Template(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']))

gemma_template = Template(
['<bos>'],
['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
['<end_of_turn>\n'], ['<end_of_turn>'], None,
['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'])
register_template(TemplateType.gemma, gemma_template)


def get_template(
template_type: str,
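To make the token layout concrete, this is the prompt the `gemma_template` registration above produces for a single-turn query without a system prompt, derived directly from its prefix/prompt/suffix lists:

```python
# Rendered form of the gemma template above for one user turn, no system
# prompt: '<bos>' prefix, then the user/model turn markers; the model's
# reply is terminated by the '<end_of_turn>' suffix.
query = 'Hello!'
prompt = (
    '<bos>'
    f'<start_of_turn>user\n{query}<end_of_turn>\n'
    '<start_of_turn>model\n'
)
print(prompt)
# Multi-round chats are joined with the chat separator '<end_of_turn>\n';
# with a system prompt, the prefix becomes
# '<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n' instead.
```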
13 changes: 8 additions & 5 deletions swift/trainers/mixin.py
@@ -115,7 +115,7 @@ def _add_patterns_to_gitattributes(

def init_hf_repo(self) -> None:
"""init ms repo. Compatible with transformers>=4.34"""
self.init_git_repo()
self.init_git_repo(at_init=True)

def init_git_repo(self, at_init: bool = False) -> None:
if not self.is_world_process_zero():
@@ -578,8 +578,7 @@ def _load_best_model(self):
except ValueError as e:
logger.warning(e)

def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch,
ignore_keys_for_eval):
def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
if self.control.should_log:
self.control.should_log = False
logs: Dict[str, float] = {}
@@ -595,11 +594,15 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch,
logs[k] = round(
v_scalar /
(self.state.global_step - self._globalstep_last_logged), 8)
if version.parse(
transformers.__version__) >= version.parse('4.38'):
grad_norm = args[0]
if grad_norm is not None:
logs['grad_norm'] = grad_norm
logs['learning_rate'] = self._get_learning_rate()

tr_loss -= tr_loss
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs)
super()._maybe_log_save_evaluate(tr_loss, model, trial, epoch,
ignore_keys_for_eval)
super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)
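The switch to `*args` above exists because transformers 4.38 inserted `grad_norm` as the second positional argument of `_maybe_log_save_evaluate` (hence `grad_norm = args[0]` under the version check). A minimal sketch of the dispatch, with the pre-4.38 argument order taken from the removed signature above:

```python
import transformers
from packaging import version

def grad_norm_from_args(*args):
    """Hedged sketch: recover grad_norm across transformers versions.

    <4.38:  _maybe_log_save_evaluate(tr_loss, model, trial, epoch,
                                     ignore_keys_for_eval)
    >=4.38: _maybe_log_save_evaluate(tr_loss, grad_norm, model, trial,
                                     epoch, ignore_keys_for_eval)
    With tr_loss stripped, args[0] is grad_norm only on >=4.38.
    """
    if version.parse(transformers.__version__) >= version.parse('4.38'):
        return args[0]  # may legitimately be None
    return None

# Returns 0.5 on transformers>=4.38, None on older versions.
print(grad_norm_from_args(0.5, 'model', 'trial', 0, None))
```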