Allow "weight: 0" in messages to mask them
Allow message objects to carry the additional key `weight`, which can be set
to 0 (or 1) to mask that message out (or leave it unmasked) for training,
similar to [1]. This is helpful for training the model to be robust and
capable of error recovery after a bad assistant message.
A missing `weight` key defaults to weight 1, to guarantee backward compatibility.

Extend `tests/prompt_strategies/test_sharegpt.py` to contain messages with `weight` keys.

Extend `src/axolotl/prompters.py::_build_result` and
`src/axolotl/prompt_strategies/sharegpt.py::SimpleShareGPTPromptTokenizingStrategy::get_conversation_thread`
to return the turns with their weights as an additional tuple element.
Do this in axolotl directly instead of modifying `fastchat.conversation`'s `Conversation`.

Extend `src/axolotl/prompt_tokenizers.py::tokenize_prompt` to mask out tokens when weight is set to 0.

[1]: https://github.com/mistralai/mistral-finetune
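
For illustration, a ShareGPT-style sample using the new key could look like the
following (the conversation content here is made up; only the `weight` semantics
follow this commit):

# Hypothetical ShareGPT-style sample: weight 0 masks a turn from the loss,
# weight 1 or a missing key leaves it trainable.
sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},                        # no weight -> treated as 1
        {"from": "gpt", "value": "2 + 2 is 5.", "weight": 0},                # bad answer, masked out
        {"from": "human", "value": "That is wrong, try again.", "weight": 0},
        {"from": "gpt", "value": "You are right, 2 + 2 is 4.", "weight": 1},  # recovery turn, trained on
    ]
}
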
DavidFarago committed Jun 12, 2024
1 parent 5783839 commit d9bbf5d
Showing 4 changed files with 72 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -143,6 +143,9 @@ def get_conversation_thread(self, prompt):
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
"weight": 1
if "weight" not in t or t["weight"] is None
else t["weight"],
}
for t in conversations
]
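
A minimal sketch of the defaulting rule this hunk applies to each turn (a
hypothetical standalone helper, not part of the diff):

def default_weight(turn: dict) -> int:
    # A missing or null "weight" key counts as 1, i.e. the turn stays unmasked.
    return 1 if turn.get("weight") is None else turn["weight"]

assert default_weight({"from": "gpt", "value": "hi"}) == 1
assert default_weight({"from": "gpt", "value": "hi", "weight": 0}) == 0
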
15 changes: 12 additions & 3 deletions src/axolotl/prompt_tokenizers.py
@@ -377,7 +377,11 @@ def tokenize_prompt(self, prompt):
LOG.warning(f"expected tuple, got {part}")
continue

role, content = part
if len(part) <= 2:
role, content = part
weight = 1
else:
role, content, weight = part

# Uses "in" because role contains extra characters
input_turn = any(r.lower() in role.lower() for r in input_roles)
@@ -403,7 +407,7 @@ def tokenize_prompt(self, prompt):
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs:
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
@@ -439,13 +443,18 @@ def tokenize_prompt(self, prompt):
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
if weight == 0:
# everything from this is masked out from the labels
# (role is masked out too because it makes no sense if contents is masked out)
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])

elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs:
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
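
Condensed, the new per-turn labeling rule amounts to the following (a rough
standalone sketch that ignores the role-prefix masking details; IGNORE_TOKEN_ID
is the usual -100 label index that the loss ignores):

IGNORE_TOKEN_ID = -100  # label value ignored by the loss

def turn_labels(input_ids, weight, train_on_inputs, is_input_turn):
    # A weight of 0 masks the whole turn, role tokens included; input turns are
    # also masked unless training on inputs is enabled.
    if weight == 0 or (is_input_turn and not train_on_inputs):
        return [IGNORE_TOKEN_ID] * len(input_ids)
    return list(input_ids)
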
24 changes: 22 additions & 2 deletions src/axolotl/prompters.py
@@ -319,6 +319,7 @@ def _build_result(self, source):

conv = self._conversation.copy()

original_source = source.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"])
@@ -360,8 +361,27 @@ def _build_result(self, source):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")

conv.append_message(role, sentence["value"])

return conv.get_turns()
turns = list(conv.get_turns())
original_source_length = len(original_source)
assert len(turns) in [
original_source_length - 1,
original_source_length,
original_source_length + 1,
]
if len(turns) == original_source_length + 1:
original_source = [{"weight": None}] + original_source
elif len(turns) == original_source_length - 1:
original_source = original_source[1:]
return [
(*turn, weight)
for turn, weight in zip(
turns,
[
1 if "weight" not in e or e["weight"] is None else e["weight"]
for e in original_source
],
)
]

def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
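
The alignment is the subtle part here: the fastchat conversation template may
prepend its own default system turn (one more turn than source messages) or
fold a provided system message into the template (one fewer). A condensed,
hypothetical version of the pairing above:

def attach_weights(turns, source):
    # Pair each rendered (role, message) turn with the weight of the source
    # message it came from.
    if len(turns) == len(source) + 1:
        source = [{"weight": None}] + source  # template added a default system turn; never mask it
    elif len(turns) == len(source) - 1:
        source = source[1:]  # the source system message was folded into the template
    weights = [1 if msg.get("weight") is None else msg["weight"] for msg in source]
    return [(*turn, weight) for turn, weight in zip(turns, weights)]
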
42 changes: 35 additions & 7 deletions tests/prompt_strategies/test_sharegpt.py
@@ -33,10 +33,22 @@ def fixture_sharegpt_dataset():
{
"from": "human",
"value": "hello",
"weight": 1,
},
{
"from": "gpt",
"value": "hello",
"weight": 0,
},
{
"from": "human",
"value": "rehello",
"weight": 0,
},
{
"from": "gpt",
"value": "rehello",
"weight": 1,
},
{
"from": "human",
@@ -45,6 +57,7 @@ def fixture_sharegpt_dataset():
{
"from": "gpt",
"value": "goodbye",
"weight": 0,
},
]
}
@@ -156,6 +169,10 @@ def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 11310, 4896, 128009,
128006, 78191, 128007,
271, 11310, 4896, 128009,
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
@@ -192,12 +209,14 @@ def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
32001, 1587, 13, 25997, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 32000, 28705, 13, # human
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 267, 21558, 32000, 28705, 13, # human
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
]
# fmt: on

def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
@@ -219,13 +238,17 @@ def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
-100, # bos
-100, -100, -100, -100, -100, -100, -100, # system
-100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 21558, 32000, 28705, 13, # gpt
# -100, -100, 13, 21558, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, # gpt with weight zero
-100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 267, 21558, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 12684, 17664, 32000, 28705, 13, # gpt
# -100, -100, 13, 12684, 17664, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight zero
]
# fmt: on

def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
@@ -247,9 +270,14 @@ def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
1, # bos
32001, 1587, 13, 25997, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 32000, 28705, 13, # human
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
# 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, # gpt with weight 0
# 32001, 2188, 13, 267, 21558, 32000, 28705, 13, # human
-100, -100, -100, -100, -100, -100, -100, -100, # human with weight 0
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
# 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight 0
]
# fmt: on

