From 36b244db2e88f718b184d314fcbbed8801937ed4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 5 Jan 2024 11:15:43 -0500
Subject: [PATCH] fix up hardcoding, lint

---
 examples/tiny-llama/pretrain.yml    |  6 ++++--
 src/axolotl/cli/train.py            |  2 --
 src/axolotl/core/trainer_builder.py | 11 +++++++++--
 src/axolotl/utils/collators.py      | 11 +++--------
 src/axolotl/utils/data.py           | 29 ++++++++++++++++++++++-------
 5 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml
index b0319a95f..dfd1bfca2 100644
--- a/examples/tiny-llama/pretrain.yml
+++ b/examples/tiny-llama/pretrain.yml
@@ -9,7 +9,9 @@ load_in_4bit: false
 strict: false
 
 max_steps: 200
-pretraining_dataset: c4
+pretraining_dataset:
+  path: c4
+  name: en
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
@@ -45,7 +47,7 @@ xformers_attention:
 flash_attention: true
 
 warmup_steps: 10
-evals_per_epoch: 
+evals_per_epoch:
 eval_table_size:
 saves_per_epoch: 1
 debug:
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 54242dd58..2248784df 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -5,7 +5,6 @@
 from pathlib import Path
 
 import fire
-import torch
 import transformers
 
 from axolotl.cli import (
@@ -20,7 +19,6 @@ from axolotl.train import train
 
 LOG = logging.getLogger("axolotl.cli.train")
 
-# torch.set_printoptions(threshold=10000)
 
 
 def do_cli(config: Path = Path("examples/"), **kwargs):
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 7b5fef570..dc8b1501e 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -60,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    pretraining: bool = field(
+        default=False,
+        metadata={
+            "help": "Indicates to trainer whether we are doing continued pretraining."
+        },
+    )
     sample_packing: bool = field(
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
@@ -157,7 +163,7 @@ def create_scheduler(
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing and False:
+        if self.args.sample_packing and not self.args.pretraining:
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
                 self.args.train_batch_size,
@@ -193,7 +199,7 @@ def _get_eval_sampler(
         return super()._get_eval_sampler(eval_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
-        if self.args.sample_packing and False:
+        if self.args.sample_packing and not self.args.pretraining:
             train_dataset = self.train_dataset
             train_dataset = train_dataset.remove_columns(["length"])
             data_collator = self.data_collator
@@ -768,6 +774,7 @@ def build(self, total_num_steps):
             training_arguments_kwargs
         )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
 
         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index b4c4fa4df..b9c1c3b3c 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -179,6 +179,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
             "labels": labels,
         }
 
+
 @dataclass
 class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     """
@@ -191,16 +192,10 @@ def __call__(self, features, return_tensors=None):
             if feature == "length":
                 continue
             if feature == "attention_mask":
-                arrays = [
-                    (1) * np.array(item)
-                    for item in features[feature]
-                ]
+                arrays = [(1) * np.array(item) for item in features[feature]]
                 chunked_data[feature] = np.concatenate(arrays)
             else:
-                arrays = [
-                    np.array(item) for item in features[feature]
-                ]
+                arrays = [np.array(item) for item in features[feature]]
                 chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
-
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 199d4a64a..0de5d4ee3 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -4,9 +4,8 @@
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple, Union, Optional
+from typing import Dict, List, Tuple, Union
 
-import numpy as np
 import torch
 from datasets import (
     Dataset,
@@ -42,7 +41,7 @@
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
 )
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.samplers.multipack import MultipackBatchSampler
@@ -70,10 +69,17 @@ def prepare_dataset(cfg, tokenizer):
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
     else:
+        path = cfg.pretraining_dataset
+        name = None
+        if isinstance(cfg.pretraining_dataset, dict):
+            path = cfg.pretraining_dataset.path
+            name = cfg.pretraining_dataset.name
+
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset,
+            path,
             tokenizer,
             cfg,
+            name=name,
             max_tokens=cfg.sequence_len,
             seed=cfg.seed or 42,
         )
@@ -815,13 +821,22 @@ def encode_pretraining(
 
 def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
     if cfg.sample_packing:
-        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens)
-        encode = functools.partial(encode_packed_pretraining, tokenizer, collate_fn, max_seq_length=max_tokens, batch_size=cfg.micro_batch_size)
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            tokenizer, return_tensors="pt", padding=True, pad_to_multiple_of=max_tokens
+        )
+        encode = functools.partial(
+            encode_packed_pretraining,
+            tokenizer,
+            collate_fn,
+            max_seq_length=max_tokens,
+            batch_size=cfg.micro_batch_size,
+        )
+        # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = load_dataset(path, streaming=True, split="train", name="en")
+    dataset = load_dataset(path, streaming=True, split="train", name=name)
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     dataset = dataset.map(
         encode,
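
Note on the change above (illustration only, not part of the patch): `pretraining_dataset` may now be either a bare dataset path or a mapping with `path` and `name` keys, which `prepare_dataset` resolves into the arguments that `load_pretraining_dataset` forwards to `datasets.load_dataset`. A minimal sketch of that resolution, assuming a plain dict for the config value rather than axolotl's `DictDefault`:

# Minimal sketch of the path/name resolution added in data.py.
# Assumes a plain dict config value; axolotl itself uses DictDefault.
from datasets import load_dataset


def resolve_pretraining_dataset(pretraining_dataset):
    """Return (path, name) for either the string or the mapping config form."""
    path = pretraining_dataset
    name = None
    if isinstance(pretraining_dataset, dict):
        path = pretraining_dataset.get("path")
        name = pretraining_dataset.get("name")
    return path, name


# Mirrors examples/tiny-llama/pretrain.yml after this patch.
path, name = resolve_pretraining_dataset({"path": "c4", "name": "en"})
dataset = load_dataset(path, name=name, streaming=True, split="train")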