Skip to content

Commit

Permalink
Support Yi-6b sft (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet committed Nov 4, 2023
1 parent 0bfc662 commit 0b3f840
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 0 deletions.
15 changes: 15 additions & 0 deletions examples/pytorch/llm/scripts/yi_6b/lora/infer.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Experimental environment: A10
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python llm_infer.py \
--ckpt_dir "output/yi-6b/vx_xxx/checkpoint-xxx" \
--load_args_from_ckpt_dir true \
--eval_human false \
--max_length 256 \
--max_new_tokens 256 \
--temperature 0.9 \
--top_k 20 \
--top_p 0.9 \
--repetition_penalty 1.05 \
--do_sample true \
--merge_lora_and_save false \
36 changes: 36 additions & 0 deletions examples/pytorch/llm/scripts/yi_6b/lora/sft.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Experimental environment: A10
# 15GB GPU memory
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python llm_sft.py \
--model_id_or_path 01ai/Yi-6B \
--model_revision master \
--sft_type lora \
--tuner_backend swift \
--template_type default-generation \
--dtype bf16 \
--output_dir output \
--dataset dureader-robust-zh \
--train_dataset_sample -1 \
--num_train_epochs 1 \
--max_length 2048 \
--check_dataset_strategy warning \
--lora_rank 8 \
--lora_alpha 32 \
--lora_dropout_p 0.05 \
--lora_target_modules ALL \
--gradient_checkpointing true \
--batch_size 1 \
--weight_decay 0.01 \
--learning_rate 1e-4 \
--gradient_accumulation_steps 16 \
--max_grad_norm 0.5 \
--warmup_ratio 0.03 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 10 \
--push_to_hub false \
--hub_model_id yi-6b-qlora \
--hub_private_repo true \
--hub_token 'your-sdk-token' \
7 changes: 7 additions & 0 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class ModelType:
# other
polylm_13b = 'polylm-13b'
seqgpt_560m = 'seqgpt-560m'
yi_6b = 'yi-6b'
yi_34b = 'yi-34b'


class LoRATM(NamedTuple):
Expand All @@ -106,6 +108,7 @@ class LoRATM(NamedTuple):
xverse = ['q_proj', 'k_proj', 'v_proj']
mistral = ['q_proj', 'k_proj', 'v_proj']
ziya = ['q_proj', 'k_proj', 'v_proj']
yi = ['q_proj', 'k_proj', 'v_proj']


GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel],
Expand Down Expand Up @@ -169,6 +172,10 @@ def _register_model(
return _register_model


@register_model(ModelType.yi_34b, '01ai/Yi-34B', LoRATM.yi,
TemplateType.default_generation)
@register_model(ModelType.yi_6b, '01ai/Yi-6B', LoRATM.yi,
TemplateType.default_generation)
@register_model(ModelType.seqgpt_560m, 'damo/nlp_seqgpt-560m', LoRATM.bloom,
TemplateType.default_generation)
@register_model(ModelType.ziya2_13b_chat, 'Fengshenbang/Ziya2-13B-Chat',
Expand Down
1 change: 1 addition & 0 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class TemplateType:
xverse = 'xverse'
ziya = 'ziya'
skywork = 'skywork'
yi = 'yi'


Prompt = List[Union[str, List[Union[str, int]]]]
Expand Down

0 comments on commit 0b3f840

Please sign in to comment.