WIP make continued pretraining work w multipack
winglian committed Jan 5, 2024
1 parent 8ed5bcb commit da9aee1
Showing 5 changed files with 107 additions and 20 deletions.
56 changes: 56 additions & 0 deletions examples/tiny-llama/pretrain.yml
@@ -0,0 +1,56 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

max_steps: 200
pretraining_dataset: c4
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out

sequence_len: 2048
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
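
For orientation, the per-device token budget implied by this example config follows directly from micro_batch_size, gradient_accumulation_steps and sequence_len. A quick arithmetic sketch (assuming a single device; with sample packing each row is filled up to sequence_len tokens, so this is an upper bound):

# Rough sanity check for the example config above (single-device assumption).
micro_batch_size = 2
gradient_accumulation_steps = 4
sequence_len = 2048

tokens_per_optimizer_step = micro_batch_size * gradient_accumulation_steps * sequence_len
print(tokens_per_optimizer_step)  # 16384 tokens per optimizer step, at most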
2 changes: 2 additions & 0 deletions src/axolotl/cli/train.py
@@ -5,6 +5,7 @@
from pathlib import Path

import fire
+import torch
import transformers

from axolotl.cli import (
@@ -19,6 +20,7 @@
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
+# torch.set_printoptions(threshold=10000)


def do_cli(config: Path = Path("examples/"), **kwargs):
6 changes: 3 additions & 3 deletions src/axolotl/core/trainer_builder.py
@@ -157,7 +157,7 @@ def create_scheduler(
        return self.lr_scheduler

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing:
+        if self.args.sample_packing and False:
            return MultipackBatchSampler(
                RandomSampler(self.train_dataset),
                self.args.train_batch_size,
@@ -193,7 +193,7 @@ def _get_eval_sampler(
        return super()._get_eval_sampler(eval_dataset)

    def get_train_dataloader(self) -> DataLoader:
-        if self.args.sample_packing:
+        if self.args.sample_packing and False:
            train_dataset = self.train_dataset
            train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
@@ -808,7 +808,7 @@ def build(self, total_num_steps):
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            args=training_args,
-            data_collator=self.build_collator(**data_collator_kwargs),
+            # data_collator=self.build_collator(**data_collator_kwargs),
            bench_data_collator=transformers.DataCollatorForSeq2Seq(
                self.tokenizer,
                return_tensors="pt",
26 changes: 26 additions & 0 deletions src/axolotl/utils/collators.py
@@ -178,3 +178,29 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
"input_ids": input_ids,
"labels": labels,
}

@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to the using the BatchSampler
"""

def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features.keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item)
for item in features[feature]
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item) for item in features[feature]
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
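
As a rough illustration of what this collator does before delegating to DataCollatorForSeq2Seq: each feature's sub-sequences for one sampled batch are concatenated into a single packed row, which the parent collator then pads up to pad_to_multiple_of. The sketch below is not the library code; pack_batch, the pad values (0 for inputs and mask, -100 for labels) and the example feature dict are illustrative assumptions.

import numpy as np

def pack_batch(features: dict, pad_to_multiple_of: int, pad_token_id: int = 0) -> dict:
    """Concatenate per-feature sub-sequences into one packed, padded row."""
    packed = {}
    for name, seqs in features.items():
        if name == "length":
            continue  # bookkeeping column, not a model input
        packed[name] = np.concatenate([np.array(s) for s in seqs])
    total = len(packed["input_ids"])
    target = -(-total // pad_to_multiple_of) * pad_to_multiple_of  # round up to a multiple
    pad = target - total
    packed["input_ids"] = np.pad(packed["input_ids"], (0, pad), constant_values=pad_token_id)
    packed["attention_mask"] = np.pad(packed["attention_mask"], (0, pad), constant_values=0)
    packed["labels"] = np.pad(packed["labels"], (0, pad), constant_values=-100)
    return packed

features = {
    "input_ids": [[5, 6, 7], [8, 9]],
    "attention_mask": [[1, 1, 1], [1, 1]],
    "labels": [[5, 6, 7], [8, 9]],
}
print(pack_batch(features, pad_to_multiple_of=8)["input_ids"])  # [5 6 7 8 9 0 0 0]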

37 changes: 20 additions & 17 deletions src/axolotl/utils/data.py
@@ -4,7 +4,7 @@
import logging
from collections import defaultdict
from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Optional

import numpy as np
import torch
@@ -42,6 +42,7 @@
    SummarizeTLDRPrompter,
    UnsupportedPrompter,
)
+from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
@@ -72,6 +73,7 @@ def prepare_dataset(cfg, tokenizer):
        train_dataset = load_pretraining_dataset(
            cfg.pretraining_dataset,
            tokenizer,
+            cfg,
            max_tokens=cfg.sequence_len,
            seed=cfg.seed or 42,
        )
@@ -811,9 +813,15 @@ def encode_pretraining(
    return ret


-def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
-    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-    dataset = load_dataset(path, streaming=True, split="train")
+def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+    if cfg.sample_packing:
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens)
+        encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size)
+        cfg.micro_batch_size = 1
+    else:
+        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+    dataset = load_dataset(path, streaming=True, split="train", name="en")
    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
    dataset = dataset.map(
        encode,
@@ -828,9 +836,10 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):

def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase,
+    collate_fn,
    examples: List[str],
-    max_seq_length: int = 8192,
-    sample_packing_efficiency: int = 1,
+    max_seq_length: int = 2048,
+    batch_size: int = 4,
) -> Dict[str, List]:
    # pylint: disable=duplicate-code
    # tokenize all the examples
@@ -859,33 +868,27 @@

    sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
-        batch_size=1,
+        batch_size=batch_size,
        drop_last=True,
-        batch_max_len=max_seq_length,
+        batch_max_len=batch_size * max_seq_length,
        lengths=(
            train_dataset.data.column("position_ids")
            .to_pandas()
            .apply(lambda x: x[-1] + 1)
            .values
        ),
-        packing_efficiency_estimate=sample_packing_efficiency,
    )

    chunked_data = defaultdict(list)

    for data in sampler:
        features = train_dataset[data]
+        features["labels"] = features["input_ids"].copy()
+        collated_features = collate_fn(features)

        for feature in features.keys():
-            if feature == "length":
-                continue
-            if feature == "attention_mask":
-                arrays = [(1) * np.array(item) for item in features[feature]]
-                chunked_data[feature].append(np.concatenate(arrays))
-            else:
-                arrays = [np.array(item) for item in features[feature]]
-                chunked_data[feature].append(np.concatenate(arrays))
-
-    chunked_data["labels"] = chunked_data["input_ids"].copy()
+            chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data
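
Putting the data.py changes together: encode_packed_pretraining groups pre-tokenized samples with MultipackBatchSampler into batches of at most batch_size * max_seq_length tokens, each batch is collated into a single padded row, and load_pretraining_dataset resets cfg.micro_batch_size to 1 because each row already carries a full micro-batch worth of tokens. Below is a naive first-fit sketch of that bin-capacity arithmetic; it is not the MultipackBatchSampler algorithm, and naive_pack plus the sample lengths are made up for illustration.

from typing import List

def naive_pack(lengths: List[int], batch_size: int, max_seq_length: int) -> List[List[int]]:
    """First-fit packing of sample lengths into bins of batch_size * max_seq_length tokens."""
    capacity = batch_size * max_seq_length
    bins: List[List[int]] = []  # sample indices per packed row
    used: List[int] = []        # tokens already placed in each bin
    for idx, length in enumerate(lengths):
        for b, tokens in enumerate(used):
            if tokens + length <= capacity:
                bins[b].append(idx)
                used[b] += length
                break
        else:
            bins.append([idx])
            used.append(length)
    return bins

# micro_batch_size=2 and sequence_len=2048 as in the example config
print(naive_pack([1500, 900, 2048, 300, 1800], batch_size=2, max_seq_length=2048))
# [[0, 1, 3], [2, 4]] -> two packed rows, each later padded to 4096 tokens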
