Skip to content

Commit

Permalink
add ultrachat prompt strategies (axolotl-ai-cloud#996)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 29, 2023
1 parent 41353d2 commit ba043a3
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
return strategy


def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy


def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
Expand Down Expand Up @@ -109,3 +126,17 @@ def get_conversation_thread(self, prompt):
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns


class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps ultrachat data to sharegpt format
"""

def get_conversation_thread(self, prompt):
conversations = prompt["messages"]
role_map = {"user": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns

0 comments on commit ba043a3

Please sign in to comment.