streaming multipack for pretraining dataset #959

Merged
[Feat] streaming multipack
[email protected] authored and winglian committed Jan 5, 2024
commit 8ed5bcb54aacaf231fbdf592ca5f432200ad01d5
70 changes: 70 additions & 0 deletions src/axolotl/utils/data.py
@@ -2,9 +2,11 @@
import functools
import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from datasets import (
    Dataset,
@@ -14,6 +16,7 @@
    load_from_disk,
)
from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -41,9 +44,11 @@
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
    process_pretraining_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")
@@ -819,3 +824,68 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
        remove_columns=dataset.features.keys(),
    )
    return dataset


def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase,
    examples: List[str],
    max_seq_length: int = 8192,
    sample_packing_efficiency: int = 1,
) -> Dict[str, List]:
    # pylint: disable=duplicate-code
    # tokenize all the examples; overlong rows are split into overlapping
    # chunks with a 256-token stride, and max_length leaves one slot free
    # so each chunk can be terminated with EOS below
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_seq_length - 1,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )

    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
    attention_mask = [seq + [1] for seq in res["attention_mask"]]

    tokenized_examples = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }

    train_dataset = Dataset.from_dict(tokenized_examples)
    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset, max_seq_length
    )

    sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
        batch_size=1,
        drop_last=True,
        batch_max_len=max_seq_length,
        lengths=(
            train_dataset.data.column("position_ids")
            .to_pandas()
            .apply(lambda x: x[-1] + 1)
            .values
        ),
        packing_efficiency_estimate=sample_packing_efficiency,
    )

    chunked_data = defaultdict(list)

    for data in sampler:
        features = train_dataset[data]

        for feature in features.keys():
            if feature == "length":
                continue
            # concatenate every row in the pack into one flat sequence
            arrays = [np.array(item) for item in features[feature]]
            chunked_data[feature].append(np.concatenate(arrays))

    # packed rows train a causal LM, so labels simply mirror input_ids
    chunked_data["labels"] = chunked_data["input_ids"].copy()

    return chunked_data
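
Because the tokenizer call above sets return_overflowing_tokens=True with stride=256, a document longer than max_seq_length - 1 tokens is split into several chunks, each re-reading the last 256 tokens of the previous one rather than being truncated. A toy sketch of the same mechanism, with hypothetical, much smaller sizes:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    res = tokenizer(
        ["word " * 50],                 # one long document
        truncation=True,
        max_length=16,                  # toy chunk size instead of max_seq_length - 1
        return_overflowing_tokens=True,
        stride=4,                       # each chunk repeats the previous chunk's last 4 tokens
    )
    print(len(res["input_ids"]))        # several overlapping chunks, not one truncated row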
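
For intuition about what MultipackBatchSampler contributes, here is a minimal, self-contained sketch of the bin-packing idea: sequences are grouped so each group's total token count fits within batch_max_len, and each group is then concatenated into one packed row. This first-fit-decreasing loop (function name and all) is an illustrative approximation, not the sampler's actual algorithm:

    def pack_first_fit_decreasing(lengths, batch_max_len):
        """Greedily assign sequence indices to packs of at most batch_max_len tokens."""
        packs = []  # each entry: [remaining_capacity, [sequence indices]]
        for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
            for pack in packs:
                if pack[0] >= lengths[idx]:
                    pack[0] -= lengths[idx]
                    pack[1].append(idx)
                    break
            else:
                packs.append([batch_max_len - lengths[idx], [idx]])
        return [indices for _, indices in packs]

    # Example: sequences of 5000, 3000, 2000 and 1500 tokens fill two 8192-token rows.
    print(pack_first_fit_decreasing([5000, 3000, 2000, 1500], 8192))
    # -> [[0, 1], [2, 3]]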
10 changes: 10 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
    return train_dataset, eval_dataset


def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
    drop_long = partial(drop_long_seq, sequence_len=sequence_len)

    train_dataset = train_dataset.filter(drop_long)
    train_dataset = train_dataset.map(
        add_position_ids,
    )
    return train_dataset


def calculate_total_num_steps(cfg, train_dataset, update=True):
    if not cfg.total_num_tokens:
        total_num_tokens = np.sum(
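
The helpers drop_long_seq and add_position_ids used above come from elsewhere in axolotl.utils.trainer and are not part of this diff. A minimal sketch of the behavior process_pretraining_datasets_for_packing appears to rely on — inferred from usage here, not copied from the source — would be:

    def drop_long_seq(sample, sequence_len=2048):
        # keep only rows that still fit in one packed sequence
        return len(sample["input_ids"]) <= sequence_len

    def add_position_ids(sample):
        # per-row position ids (0..len-1); encode_packed_pretraining reads the
        # last value + 1 as the row's token count when building the sampler
        sample_len = len(sample["input_ids"])
        sample["position_ids"] = list(range(sample_len))
        sample["length"] = sample_len  # assumed source of the "length" column skipped during chunking
        return sample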
87 changes: 87 additions & 0 deletions tests/test_packed_pretraining.py
@@ -0,0 +1,87 @@
"""Module for testing streaming dataset sequence packing"""
import math
import unittest
from functools import partial

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining


class TestPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""

def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "[PAD]",
}
)
self.max_seq_length = 8192
self.batch_size = 6
self.sample_packing_efficiency = 1
self.data_collator_kwargs = {
"padding": True,
"pad_to_multiple_of": 64 * math.ceil(self.max_seq_length / 64),
}

def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
"c4",
"en",
streaming=True,
)["train"]

encode = partial(
encode_packed_pretraining,
self.tokenizer,
max_seq_length=self.max_seq_length,
sample_packing_efficiency=self.sample_packing_efficiency,
)

dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=dataset.features.keys(),
)

data_collator_fn = DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**self.data_collator_kwargs,
)

trainer_loader = DataLoader(
dataset,
batch_size=self.batch_size,
collate_fn=data_collator_fn,
drop_last=True,
)
idx = 0
for data in trainer_loader:
if idx > 10:
break
assert data["input_ids"].shape == (self.batch_size, self.max_seq_length)
assert data["position_ids"].shape == (self.batch_size, self.max_seq_length)
assert data["labels"].shape == (self.batch_size, self.max_seq_length)
assert data["attention_mask"].shape == (
self.batch_size,
self.max_seq_length,
)
idx += 1


if __name__ == "__main__":
unittest.main()
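
One detail worth noting in the collator setup: pad_to_multiple_of is computed as 64 * math.ceil(max_seq_length / 64), i.e. the smallest multiple of 64 at or above max_seq_length. Since 8192 is itself a multiple of 64, every collated batch pads to exactly max_seq_length tokens, which is what the fixed-shape assertions above depend on. A quick check:

    import math

    max_seq_length = 8192
    pad_to_multiple_of = 64 * math.ceil(max_seq_length / 64)
    assert pad_to_multiple_of == 8192  # already 64-aligned, so every row pads to max_seq_length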