Skip to content

Commit

Permalink
Fix hardcoded data collator for multipack pretraining
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 5, 2024
1 parent 789c972 commit 2a49248
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
# data_collator=self.build_collator(**data_collator_kwargs),
data_collator=self.build_collator(training_args, **data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
Expand All @@ -836,7 +836,10 @@ def build(self, total_num_steps):

return trainer

def build_collator(self, **kwargs):
def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
if training_args.pretraining:
return None

if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)

Expand Down

0 comments on commit 2a49248

Please sign in to comment.