m3e distributed learning (#70)
* ✨feat: support M3E distributed training

* make blue happy

---------

Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin committed Jul 26, 2023
1 parent d9da840 commit 72fc675
Showing 2 changed files with 21 additions and 3 deletions.
examples/finetune_jsonl.py (7 additions, 1 deletion)
@@ -1,4 +1,5 @@
import pandas as pd
from accelerate import DistributedDataParallelKwargs
from uniem.finetuner import FineTuner

# Read the jsonl file
@@ -7,4 +8,9 @@
df = df.rename(columns={'instruction': 'text', 'output': 'text_pos'})
# Specify m3e-small as the model to fine-tune
finetuner = FineTuner.from_pretrained('moka-ai/m3e-small', dataset=df.to_dict('records'))
finetuner.run(epochs=1, output_dir='finetuned-model-riddle')
finetuner.run(
epochs=1,
output_dir='finetuned-model-riddle',
batch_size=32,
accelerator_kwargs={'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters=True)]},
)
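
A note for context (not part of the diff): `accelerator_kwargs` is presumably forwarded by the finetuner to accelerate's `Accelerator`, and the `DistributedDataParallelKwargs` handler with `find_unused_parameters=True` tells torch DDP to tolerate parameters that receive no gradient in a given step, a common requirement when only part of a model contributes to the loss. A minimal sketch of that pattern, assuming a standard accelerate setup:

```python
from accelerate import Accelerator, DistributedDataParallelKwargs

# Sketch of the pattern the example relies on: extra keyword arguments are
# passed through to accelerate's Accelerator, including kwargs_handlers.
# find_unused_parameters=True lets torch DDP skip gradient reduction for
# parameters that did not take part in the current forward/backward pass.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```

The example itself would typically be launched through the accelerate CLI (for instance `accelerate launch examples/finetune_jsonl.py`) so that multiple processes are spawned for DDP.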
scripts/train_m3e.py (14 additions, 2 deletions)
@@ -7,7 +7,7 @@
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import Dataset as HfDataset
from datasets import concatenate_datasets, load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, SequentialSampler
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup # type: ignore
from uniem.data import M3EDataset, M3EHfDatsetWithInfo, PairCollator
from uniem.model import (
@@ -91,16 +91,23 @@ def main(
gradient_accumulation_steps=gradient_accumulation_steps,
project_config=project_config,
log_with=['tensorboard'] if use_tensorboard else None,
dispatch_batches=True,
split_batches=True,
)
accelerator.init_trackers('m3e')
accelerator.print(f'Parameters: {locals()}')

set_seed(seed)
accelerator.print(f'Start with seed: {seed}')
accelerator.print(f'Output dir: {output_dir}')
if config_file:
accelerator.print(f'Config File: {config_file}')

# DataLoader
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
all_m3e_datasets = load_all_datasets(m3e_datasets_dir)
with accelerator.main_process_first():
all_m3e_datasets = load_all_datasets(m3e_datasets_dir)

train_dataset = M3EDataset(
all_m3e_datasets,
batch_size=batch_size,
@@ -116,7 +123,12 @@
num_workers=num_workers,
pin_memory=True,
)

# hack dataloader for distributed training
train_dataloader.__dict__['batch_size'] = batch_size
train_dataloader = accelerator.prepare(train_dataloader)
train_dataloader.__dict__['sampler'] = SequentialSampler(train_dataloader.dataset)
train_dataloader.__dict__['batch_sampler'] = None

embedder = create_uniem_embedder(
model_name_or_path=model_name_or_path,
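
For background on the two new `Accelerator` flags above (an illustrative sketch, not part of the commit): `M3EDataset` appears to yield pre-built batches, so the script has the main process iterate the dataloader and dispatch slices to the workers, while keeping the per-step batch size equal to `batch_size` rather than multiplying it by the number of processes. A minimal sketch with accelerate as of mid-2023, when these were still `Accelerator` constructor arguments:

```python
from accelerate import Accelerator

# Sketch of the two flags added above (behaviour as documented by accelerate):
# dispatch_batches=True  -> only the main process iterates the dataloader and
#                           each batch is sliced and sent to the other workers.
# split_batches=True     -> a batch is split across processes, so the per-step
#                           batch size stays equal to the dataloader's
#                           batch_size instead of being multiplied by the
#                           number of processes.
accelerator = Accelerator(dispatch_batches=True, split_batches=True)
```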
