Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transformers 4.36 #218

Merged
merged 5 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix transformers==4.36
  • Loading branch information
Jintao-Huang committed Dec 14, 2023
commit 1c6eb2682fe403db1d1164e6d7ba8719f92d9868
1 change: 1 addition & 0 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def llm_sft(args: SftArguments) -> str:
save_on_each_node=args.save_on_each_node)

if args.gradient_checkpointing:
model.config.use_cache = False # fix transformers==4.36
model.enable_input_require_grads()
if is_dist():
# Compatible with https://github.com/huggingface/transformers/pull/25903
Expand Down
15 changes: 13 additions & 2 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class ModelType:
# yi
yi_6b = 'yi-6b'
yi_6b_200k = 'yi-6b-200k'
yi_6b_chat = 'yi-6b-chat'
yi_34b = 'yi-34b'
yi_34b_200k = 'yi-34b-200k'
yi_34b_chat = 'yi-34b-chat'
Expand Down Expand Up @@ -630,12 +631,22 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
if model_config is None:
model_config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
_flash_attn_2_enabled = kwargs.pop('use_flash_attn', False)
model_config._flash_attn_2_enabled = _flash_attn_2_enabled
use_flash_attn = kwargs.pop('use_flash_attn', False)
if version.parse(transformers.__version__) >= version.parse('4.36'):
if use_flash_attn:
model_config._attn_implementation = 'flash_attention_2'
else:
model_config._flash_attn_2_enabled = use_flash_attn
return get_model_tokenizer_from_repo(model_dir, torch_dtype, model_kwargs,
load_model, model_config, **kwargs)


@register_model(
ModelType.yi_6b_chat,
'01ai/Yi-6B-Chat',
LoRATM.yi,
TemplateType.yi,
support_flash_attn=True)
@register_model(
ModelType.yi_34b_chat,
'01ai/Yi-34B-Chat',
Expand Down
6 changes: 4 additions & 2 deletions swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,8 @@ def _save_sft_args(self, output_dir: str) -> None:
def _save(self, output_dir: Optional[str] = None, state_dict=None):
"""Compatible with swift and peft"""
# If we are executing this function, we are the process zero, so we don't check for that.
self.state.last_model_checkpoint = output_dir
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f'Saving model checkpoint to {output_dir}')
# configuration.json
model_dir = getattr(self.model, 'model_dir', None)
if model_dir is not None:
Expand Down Expand Up @@ -421,6 +419,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
shutil.copy(src_path, dst_path)

def _save_checkpoint(self, model, trial, metrics=None):
self.state.last_model_checkpoint = os.path.join(
self.args.output_dir, f'checkpoint-{self.state.global_step}')
logger.info(
f'Saving model checkpoint to {self.state.last_model_checkpoint}')
only_save_model = self.args.only_save_model
if only_save_model:
return self._only_save_model(model, trial, metrics)
Expand Down