
streaming multipack for pretraining dataset #959

Merged
fix up hardcoding, lint
winglian committed Jan 5, 2024
commit 36b244db2e88f718b184d314fcbbed8801937ed4
6 changes: 4 additions & 2 deletions examples/tiny-llama/pretrain.yml
@@ -9,7 +9,9 @@ load_in_4bit: false
strict: false

max_steps: 200
pretraining_dataset: c4
pretraining_dataset:
  path: c4
  name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out
@@ -45,7 +47,7 @@ xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
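The example config now nests the pretraining corpus under pretraining_dataset, with path pointing at the Hugging Face dataset and name selecting the subset. A minimal sketch (not part of the diff) of how that pair reaches the streaming loader, mirroring load_pretraining_dataset further down in this commit; the "text" field is the usual c4 column:

from datasets import load_dataset

# Illustrative only: "c4" and "en" come from the example config above.
dataset = load_dataset("c4", name="en", streaming=True, split="train")
print(next(iter(dataset))["text"][:80])  # peek at one streamed document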
2 changes: 0 additions & 2 deletions src/axolotl/cli/train.py
@@ -5,7 +5,6 @@
from pathlib import Path

import fire
import torch
import transformers

from axolotl.cli import (
@@ -20,7 +19,6 @@
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
# torch.set_printoptions(threshold=10000)


def do_cli(config: Path = Path("examples/"), **kwargs):
11 changes: 9 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ def create_scheduler(
return self.lr_scheduler

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and False:
if self.args.sample_packing and not self.args.pretraining:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
@@ -193,7 +199,7 @@ def _get_eval_sampler(
return super()._get_eval_sampler(eval_dataset)

def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and False:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
@@ -768,6 +774,7 @@ def build(self, total_num_steps):
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
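With the hard-coded "and False" guards replaced by a real pretraining training argument, multipack sampling and the custom dataloader are only used when we are not streaming a pretraining dataset (where packing already happens in the dataset map). A quick stand-alone illustration of how the flag is derived via bool(self.cfg.pretraining_dataset) for the possible config shapes:

# The three values stand in for an unset, string-style, and mapping-style config.
for pretraining_dataset in (None, "c4", {"path": "c4", "name": "en"}):
    print(repr(pretraining_dataset), "->", bool(pretraining_dataset))
# None -> False, 'c4' -> True, {'path': 'c4', 'name': 'en'} -> True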
11 changes: 3 additions & 8 deletions src/axolotl/utils/collators.py
@@ -179,6 +179,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
"labels": labels,
}


@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
@@ -191,16 +192,10 @@ def __call__(self, features, return_tensors=None):
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item)
for item in features[feature]
]
arrays = [(1) * np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item) for item in features[feature]
]
arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
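PretrainingBatchSamplerDataCollatorForSeq2Seq flattens the per-sample arrays of an already-packed batch into single concatenated sequences before deferring to the base Seq2Seq collator. A stand-alone numpy sketch of that concatenation step, with toy values in place of real tokenized inputs:

import numpy as np

# Toy stand-in for the tokenized features the collator receives above.
features = {
    "input_ids": [[1, 2, 3], [4, 5]],
    "attention_mask": [[1, 1, 1], [1, 1]],
    "length": [3, 2],  # skipped, just like in the collator
}
chunked_data = {}
for feature, items in features.items():
    if feature == "length":
        continue
    chunked_data[feature] = np.concatenate([np.array(item) for item in items])
print(chunked_data["input_ids"])       # [1 2 3 4 5]
print(chunked_data["attention_mask"])  # [1 1 1 1 1]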

29 changes: 22 additions & 7 deletions src/axolotl/utils/data.py
@@ -4,9 +4,8 @@
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union, Optional
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from datasets import (
Dataset,
@@ -42,7 +41,7 @@
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
@@ -70,10 +69,17 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, dict):
path = cfg.pretraining_dataset.path
name = cfg.pretraining_dataset.name

train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
path,
tokenizer,
cfg,
name=name,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
@@ -815,13 +821,22 @@ def encode_pretraining(

def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens)
encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size)
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

dataset = load_dataset(path, streaming=True, split="train", name="en")
dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map(
encode,
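Because the streaming map now emits examples that are already packed micro-batches, cfg.micro_batch_size is reset to 1 so the downstream data loader does not batch them a second time. A toy sketch of the packing idea in plain Python (greedy and simplified; not the real encode_packed_pretraining or MultipackBatchSampler logic):

# Short tokenized documents are packed into fixed-length rows; each mapped
# example then already holds a full micro-batch worth of rows.
max_seq_length = 8
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]

rows, current = [], []
for doc in docs:
    if current and len(current) + len(doc) > max_seq_length:
        rows.append(current)
        current = []
    current = current + doc
if current:
    rows.append(current)

print(rows)  # [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]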