Allow "weight: 0" in messages to mask them #1703

Merged


DavidFarago
Contributor

@DavidFarago DavidFarago commented Jun 11, 2024

Allow message objects to carry an additional key, weight, which can be set to 0 to mask that message out of training, or to 1 to leave it unmasked (similar to [1]).

Description

Extend src/axolotl/prompters.py::_build_result to return each turn with its weight as an additional tuple element. Do this in axolotl directly instead of modifying fastchat.conversation's Conversation.
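As a rough illustration of what "weight as an additional tuple element" could look like, here is a minimal sketch. It is not the code from this PR: `build_turns_with_weights` is a hypothetical helper, and the `from` / `value` message keys are assumed to follow the usual ShareGPT layout.

```python
from typing import Dict, Iterator, List, Tuple


def build_turns_with_weights(messages: List[Dict]) -> Iterator[Tuple[str, str, int]]:
    """Yield (role, content, weight) for each message.

    Hypothetical helper for illustration only; the real _build_result
    may structure its tuples differently.
    """
    for msg in messages:
        # A missing weight key falls back to 1 (message stays unmasked).
        weight = msg.get("weight", 1)
        if weight not in (0, 1):
            raise ValueError(f"weight must be 0 or 1, got {weight!r}")
        yield msg["from"], msg["value"], weight
```

Carrying the weight alongside each turn keeps the change local to axolotl, so the fastchat Conversation object does not need to know about it.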

Extend src/axolotl/prompt_tokenizers.py::tokenize_prompt to mask out tokens when weight is set to 0.
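The masking step can be sketched as follows. This is a simplified stand-in, not the actual tokenize_prompt code: `labels_for_turn` is a hypothetical name, and the sketch assumes the common convention that label value -100 is ignored by the loss.

```python
from typing import List

# Assumption: -100 is the label value the loss function skips
# (the default ignore_index of PyTorch's CrossEntropyLoss).
IGNORE_INDEX = -100


def labels_for_turn(input_ids: List[int], weight: int) -> List[int]:
    """Return training labels for one turn's tokens.

    With weight 0 every token is replaced by IGNORE_INDEX, so the turn
    stays in the context but contributes nothing to the loss.
    With weight 1 the labels are a copy of the input ids.
    """
    if weight == 0:
        return [IGNORE_INDEX] * len(input_ids)
    return list(input_ids)
```

For example, `labels_for_turn([5, 6, 7], 0)` yields three ignored positions, while the same tokens with weight 1 are kept as labels unchanged.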

Motivation and Context

This helps train the model to be robust and to recover from errors after a bad assistant message. A missing weight key defaults to weight 1, to guarantee backward compatibility.
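A hedged example of the error-recovery use case: the bad assistant reply is kept as context but given weight 0, and the corrected reply is left at weight 1. The record layout (`conversations`, `from`, `value`) is assumed from the usual ShareGPT format, and `trained_replies` is a hypothetical helper that only inspects assistant turns.

```python
conversation = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        # Bad answer: visible to the model as context, masked from the loss.
        {"from": "gpt", "value": "5", "weight": 0},
        {"from": "human", "value": "That is wrong, please check."},
        # Recovery: this is the reply the model should learn to produce.
        {"from": "gpt", "value": "Sorry, 2 + 2 = 4.", "weight": 1},
    ]
}


def trained_replies(record):
    """Return assistant replies that remain unmasked (weight defaults to 1)."""
    return [
        msg["value"]
        for msg in record["conversations"]
        if msg["from"] == "gpt" and msg.get("weight", 1) == 1
    ]
```

Existing datasets without any weight key behave as before, since every assistant reply then counts as weight 1.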

How has this been tested?

By extending tests/prompt_strategies/test_sharegpt.py with test cases whose messages contain weight keys.

@DavidFarago DavidFarago force-pushed the mr-weight-on-main branch 2 times, most recently from ccebf54 to d9bbf5d Compare June 12, 2024 13:32
Allow message objects to carry an additional `weight` key, which can be set
to 0 to mask that message out of training, or to 1 to leave it unmasked
(similar to [1]). This helps train the model to be robust and
to recover from errors after a bad assistant message.
A missing `weight` key defaults to weight 1, to guarantee backward compatibility.

Extend `src/axolotl/prompters.py::_build_result` and
`src/axolotl/prompt_strategies/sharegpt.py::SimpleShareGPTPromptTokenizingStrategy::get_conversation_thread`
to return each turn with its weight as an additional tuple element.
Do this in axolotl directly instead of modifying `fastchat.conversation`'s `Conversation`.

Extend `src/axolotl/prompt_tokenizers.py::tokenize_prompt` to mask out tokens when weight is set to 0.

Extend `tests/prompt_strategies/test_sharegpt.py` with four test cases that contain messages with `weight` keys.
Swap the names of `test_w_train_on_input` and `test_no_train_on_input`.

[1]: https://github.com/mistralai/mistral-finetune
Collaborator

@winglian winglian left a comment

thank you!

@DavidFarago
Contributor Author

you are welcome, @winglian -- thank you for reviewing.

Will you merge this PR or give me write access to this repository?

@winglian winglian merged commit 559562d into OpenAccess-AI-Collective:main Jun 20, 2024
8 checks passed
@winglian
Collaborator

merged. thanks @DavidFarago !
