fix up hardcoding, lint
winglian committed Jan 5, 2024
1 parent da9aee1 commit 36b244d
Showing 5 changed files with 38 additions and 21 deletions.
examples/tiny-llama/pretrain.yml: 6 changes (4 additions, 2 deletions)
@@ -9,7 +9,9 @@ load_in_4bit: false
strict: false

max_steps: 200
-pretraining_dataset: c4
+pretraining_dataset:
+  path: c4
+  name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out
@@ -45,7 +47,7 @@ xformers_attention:
flash_attention: true

warmup_steps: 10
-evals_per_epoch:
+evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
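
Note on the change above: the bare pretraining_dataset: c4 value, which left the dataset config name hardcoded to "en" in the loader, becomes a mapping carrying both the dataset path and its config name. As an illustrative sketch only (the values and the standalone call are assumptions; the project's real loader is in src/axolotl/utils/data.py below), such a path/name pair corresponds to a Hugging Face datasets call like:

    from datasets import load_dataset

    # Illustrative only: "c4" is the dataset path, "en" is its config (subset) name,
    # mirroring the new pretraining_dataset mapping in the example config.
    pretraining_dataset = {"path": "c4", "name": "en"}

    dataset = load_dataset(
        pretraining_dataset["path"],
        name=pretraining_dataset["name"],
        split="train",
        streaming=True,  # pretraining corpora are typically streamed rather than downloaded whole
    )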
src/axolotl/cli/train.py: 2 changes (0 additions, 2 deletions)
@@ -5,7 +5,6 @@
from pathlib import Path

import fire
-import torch
import transformers

from axolotl.cli import (
@@ -20,7 +19,6 @@
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
-# torch.set_printoptions(threshold=10000)


def do_cli(config: Path = Path("examples/"), **kwargs):
src/axolotl/core/trainer_builder.py: 11 changes (9 additions, 2 deletions)
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
+pretraining: bool = field(
+    default=False,
+    metadata={
+        "help": "Indicates to trainer whether we are doing continued pretraining."
+    },
+)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ def create_scheduler(
return self.lr_scheduler

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-if self.args.sample_packing and False:
+if self.args.sample_packing and not self.args.pretraining:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
@@ -193,7 +199,7 @@ def _get_eval_sampler(
return super()._get_eval_sampler(eval_dataset)

def get_train_dataloader(self) -> DataLoader:
-if self.args.sample_packing and False:
+if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
@@ -768,6 +774,7 @@ def build(self, total_num_steps):
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
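
Taken together, the trainer changes above replace a hardcoded "and False" (which disabled the sample-packing branches entirely) with an explicit pretraining flag, so the multipack sampler and custom dataloader are used only for non-pretraining runs, while streamed pretraining leaves packing to its collator. A toy sketch of the gating, using a hypothetical helper rather than the trainer's actual methods:

    def use_multipack_sampler(sample_packing: bool, pretraining: bool) -> bool:
        # Previously this condition was effectively "sample_packing and False", i.e. never taken.
        return sample_packing and not pretraining

    assert use_multipack_sampler(True, False)       # packed fine-tuning: multipack sampler applies
    assert not use_multipack_sampler(True, True)    # pretraining: collator packs, default sampler
    assert not use_multipack_sampler(False, False)  # packing disabled entirely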
src/axolotl/utils/collators.py: 11 changes (3 additions, 8 deletions)
@@ -179,6 +179,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
"labels": labels,
}


@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
@@ -191,16 +192,10 @@ def __call__(self, features, return_tensors=None):
if feature == "length":
continue
if feature == "attention_mask":
-arrays = [
-    (1) * np.array(item)
-    for item in features[feature]
-]
+arrays = [(1) * np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
else:
-arrays = [
-    np.array(item) for item in features[feature]
-]
+arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
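
For context on the collator touched by this lint pass: PretrainingBatchSamplerDataCollatorForSeq2Seq flattens a batch of pre-packed examples by concatenating each feature's per-example arrays into a single sequence before delegating to the base seq2seq collator. A standalone numpy illustration of that concatenation step, with toy values rather than real token ids:

    import numpy as np

    # Two already-tokenized examples for a single feature, e.g. "input_ids".
    features = {"input_ids": [[1, 2, 3], [4, 5]]}

    arrays = [np.array(item) for item in features["input_ids"]]
    packed = np.concatenate(arrays)
    print(packed)  # [1 2 3 4 5] -- one flat sequence handed to the base collator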

src/axolotl/utils/data.py: 29 changes (22 additions, 7 deletions)
@@ -4,9 +4,8 @@
import logging
from collections import defaultdict
from pathlib import Path
-from typing import Dict, List, Tuple, Union, Optional
+from typing import Dict, List, Tuple, Union

-import numpy as np
import torch
from datasets import (
Dataset,
@@ -42,7 +41,7 @@
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
@@ -70,10 +69,17 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
+path = cfg.pretraining_dataset
+name = None
+if isinstance(cfg.pretraining_dataset, dict):
+    path = cfg.pretraining_dataset.path
+    name = cfg.pretraining_dataset.name

train_dataset = load_pretraining_dataset(
-cfg.pretraining_dataset,
+path,
tokenizer,
cfg,
+name=name,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
@@ -815,13 +821,22 @@ def encode_pretraining(

def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
if cfg.sample_packing:
-collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens)
-encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size)
+collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+    tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens
+)
+encode = functools.partial(
+    encode_packed_pretraining,
+    tokenizer,
+    collate_fn,
+    max_seq_length=max_tokens,
+    batch_size=cfg.micro_batch_size,
+)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

-dataset = load_dataset(path, streaming=True, split="train", name="en")
+dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map(
encode,
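
The data.py changes above let pretraining_dataset be either a plain string (the old form) or a mapping with path and name, and forward the resolved name to load_dataset instead of the hardcoded "en". A small normalization sketch under those assumptions, using a hypothetical helper rather than the project's prepare_dataset code:

    def resolve_pretraining_dataset(value):
        """Return (path, name) for either the legacy string form or the new mapping form."""
        if isinstance(value, dict):
            return value["path"], value.get("name")
        return value, None

    assert resolve_pretraining_dataset("c4") == ("c4", None)
    assert resolve_pretraining_dataset({"path": "c4", "name": "en"}) == ("c4", "en")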
