Commit

add: training utils for kullm v3
metterian committed Aug 13, 2023
1 parent 04dd35f commit 42dd76d
Showing 1 changed file with 110 additions and 0 deletions.
utils/build_dataset.py (new file, 110 additions, 0 deletions)
@@ -0,0 +1,110 @@
import logging
import os
from dataclasses import dataclass
from typing import Dict, Sequence, Union, List
import datasets
import torch
from datasets import load_dataset, concatenate_datasets
import transformers


IGNORE_INDEX = -100  # label id that PyTorch's cross-entropy loss ignores

logger = logging.getLogger(__name__)

# English reference (translation of the Korean prompt template used below):
# PROMPT_TEMPLATE = (
#     "Below is an instruction that describes a task. "
#     "Write a response that appropriately completes the request.\n\n"
#     "### Instruction:\n{instruction}\n\n### Response: "
# )
PROMPT_TEMPLATE = (
    "아래는 작업을 설명하는 지시문(instruction) 입니다. "
    "요청을 적절하게 완료하는 응답을 작성하세요.\n\n"
    "### 지시문(instruction):\n{instruction}\n\n### 응답: "
)
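
# Illustrative data flow (not part of the commit): given a record like
#   {"instruction": "Summarize the text.", "input": "<some text>", "output": "<summary>"}
# the instruction and a non-empty input are joined with a newline and formatted into
# PROMPT_TEMPLATE to form the source; the target is the output plus tokenizer.eos_token,
# and only the target tokens keep real labels (source positions are masked to IGNORE_INDEX).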


def build_instruction_dataset(
    data_path: Union[List[str], str],
    tokenizer: transformers.PreTrainedTokenizer,
    max_seq_length: int,
    data_cache_dir=None,
    preprocessing_num_workers=None,
):
    def tokenization(examples):
        sources = []
        targets = []
        prompt = PROMPT_TEMPLATE
        for instruction, input, output in zip(examples["instruction"], examples["input"], examples["output"]):
            if input is not None and input != "":
                instruction = instruction + "\n" + input
            source = prompt.format_map({"instruction": instruction})
            target = f"{output}{tokenizer.eos_token}"

            sources.append(source)
            targets.append(target)

        tokenized_sources = tokenizer(sources, return_attention_mask=False)
        tokenized_targets = tokenizer(targets, return_attention_mask=False, add_special_tokens=False)

        all_input_ids = []
        all_labels = []
        for s, t in zip(tokenized_sources["input_ids"], tokenized_targets["input_ids"]):
            # Concatenate prompt and response tokens, then truncate to max_seq_length.
            input_ids = torch.LongTensor(s + t)[:max_seq_length]
            # Mask the prompt tokens so the loss is computed only on the response.
            labels = torch.LongTensor([IGNORE_INDEX] * len(s) + t)[:max_seq_length]
            assert len(input_ids) == len(labels)
            all_input_ids.append(input_ids)
            all_labels.append(labels)

        results = {"input_ids": all_input_ids, "labels": all_labels}
        return results

    logging.warning("building dataset...")
    all_datasets = []

    if not isinstance(data_path, (list, tuple)):
        data_path = [data_path]
    for file in data_path:
        if data_cache_dir is None:
            data_cache_dir = str(os.path.dirname(file))
        cache_path = os.path.join(data_cache_dir, os.path.basename(file).split(".")[0])
        os.makedirs(cache_path, exist_ok=True)
        try:
            # Reuse the tokenized dataset if it has already been cached on disk.
            processed_dataset = datasets.load_from_disk(cache_path)
            logger.info(f"training datasets-{file} has been loaded from disk")
        except Exception:
            raw_dataset = load_dataset("json", data_files=file, cache_dir=cache_path)
            tokenization_func = tokenization
            tokenized_dataset = raw_dataset.map(
                tokenization_func,
                batched=True,
                num_proc=preprocessing_num_workers,
                remove_columns=["instruction", "input", "output"],
                keep_in_memory=False,
                desc="preprocessing on dataset",
            )
            processed_dataset = tokenized_dataset
            processed_dataset.save_to_disk(cache_path)
        processed_dataset.set_format("torch")
        all_datasets.append(processed_dataset["train"])
    all_datasets = concatenate_datasets(all_datasets)
    return all_datasets


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        # Right-pad input_ids with the pad token and labels with IGNORE_INDEX (-100).
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
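
For context, a minimal sketch of how these two utilities might be wired together with a DataLoader follows. It is not part of the commit; the checkpoint name, data path, batch size, and sequence length are placeholders.

# Hypothetical usage sketch; checkpoint name and data path are placeholders.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from utils.build_dataset import DataCollatorForSupervisedDataset, build_instruction_dataset

tokenizer = AutoTokenizer.from_pretrained("some-org/korean-base-model")  # placeholder checkpoint
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the collator needs a pad token id

train_dataset = build_instruction_dataset(
    data_path=["data/train.json"],  # placeholder path; JSON records with instruction/input/output
    tokenizer=tokenizer,
    max_seq_length=1024,
)
collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collator)
batch = next(iter(loader))  # dict of input_ids, labels, attention_mask tensors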
