Add epoch end callback (#48)
* ✨ feat: finetuner supports epoch-end callbacks

---------

Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin committed Jul 13, 2023
1 parent d510c75 commit 9acd3eb
Showing 5 changed files with 17 additions and 7 deletions.

README.md (10 changes: 9 additions & 1 deletion)

@@ -49,10 +49,18 @@ finetuner = FineTuner.from_pretrained('moka-ai/m3e-small', dataset=dataset)
finetuner.run(epochs=3)
```

-For the fine-tuned models, see the [uniem fine-tuning tutorial](https://github.com/wangyuxinwhy/uniem/blob/main/examples/finetune.ipynb) or <a target="_blank" href="https://colab.research.google.com/github/wangyuxinwhy/uniem/blob/main/examples/finetune.ipynb">
+For fine-tuning models, see the [uniem fine-tuning tutorial](https://github.com/wangyuxinwhy/uniem/blob/main/examples/finetune.ipynb) or <a target="_blank" href="https://colab.research.google.com/github/wangyuxinwhy/uniem/blob/main/examples/finetune.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


+If you want to run it locally, run the following commands to set up the environment:
+
+```bash
+conda create -n uniem python=3.10
+conda activate uniem  # activate the new env so pip installs into it
+pip install uniem
+```

## 💯 MTEB-zh

Chinese embedding models lack a unified evaluation standard, so we drew on [MTEB](https://huggingface.co/spaces/mteb/leaderboard) to build the Chinese benchmark MTEB-zh. Six models have been evaluated across a range of datasets so far; for the detailed evaluation methodology and code, see [MTEB-zh](https://github.com/wangyuxinwhy/uniem/tree/main/mteb-zh).

pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
[tool.poetry]
name = "uniem"
version = "0.3.0"
version = "0.3.1"
description = "unified embedding model"
authors = ["wangyuxin <[email protected]>"]
license = "MIT"

uniem/finetuner.py (6 changes: 4 additions & 2 deletions)

@@ -2,7 +2,7 @@
import os
from enum import Enum
from pathlib import Path
-from typing import Iterable, Sequence, Sized, cast
+from typing import Callable, Iterable, Sequence, Sized, cast

import torch
from accelerate import Accelerator
@@ -234,15 +234,16 @@ def run(
max_length: int = 512,
drop_last: bool = False,
shuffle: bool = False,
+num_workers: int = 0,
# Trainer
epochs: int = 3,
mixed_precision: MixedPrecisionType = MixedPrecisionType.no,
gradient_accumulation_steps: int = 1,
save_on_epoch_end: bool = False,
num_max_checkpoints: int = 1,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
-num_workers: int = 0,
seed: int = 42,
+epoch_end_callbacks: Sequence[Callable[[Trainer], None]] | None = None,
output_dir: Path | str | None = None,
):

@@ -325,6 +326,7 @@ def run(
lr_scheduler=lr_scheduler,
log_interval=10,
save_on_epoch_end=save_on_epoch_end,
+epoch_end_callbacks=epoch_end_callbacks,
)
accelerator.print(f'Start training for {epochs} epochs')
trainer.train()

uniem/trainer.py (4 changes: 2 additions & 2 deletions)

@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import Any, Sized
+from typing import Any, Callable, Sequence, Sized

import torch
from accelerate import Accelerator
@@ -27,7 +27,7 @@ def __init__(
lr_scheduler: LRScheduler | None = None,
log_interval: int = 50,
save_on_epoch_end: bool = True,
-epoch_end_callbacks: list[Any] | None = None,
+epoch_end_callbacks: Sequence[Callable[['Trainer'], None]] | None = None,
):
self.model = model
self.optimizer = optimizer
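
The tightened annotation means each callback receives the `Trainer` itself and returns nothing. The call site is not shown in this diff, but a loop along these lines at the end of each epoch would match the signature (a hypothetical sketch, not the repository's code):

```python
from uniem.trainer import Trainer

# Hypothetical sketch of how epoch-end callbacks matching
# Sequence[Callable[[Trainer], None]] could be fired; the real call
# site is inside Trainer.train and is not shown in this diff. It
# assumes Trainer stores the sequence as self.epoch_end_callbacks.
def fire_epoch_end_callbacks(trainer: Trainer) -> None:
    for callback in trainer.epoch_end_callbacks or []:
        callback(trainer)  # each callback receives the Trainer instance
```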

uniem/version.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-__version__ = '0.3.0'
+__version__ = '0.3.1'
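
Taken together, these changes let callers hook into the end of each training epoch: `FineTuner.run` now accepts `epoch_end_callbacks`, typed in `Trainer` as `Sequence[Callable[[Trainer], None]]`. A minimal usage sketch, assuming a dataset prepared as in the README example (the dataset choice and callback body here are illustrative only):

```python
from datasets import load_dataset
from uniem.finetuner import FineTuner
from uniem.trainer import Trainer

def on_epoch_end(trainer: Trainer) -> None:
    # Called once at the end of every epoch with the live Trainer instance.
    print('epoch finished')

# Illustrative dataset; any dataset FineTuner accepts works here.
dataset = load_dataset('shibing624/nli_zh', 'STS-B')
finetuner = FineTuner.from_pretrained('moka-ai/m3e-small', dataset=dataset)
finetuner.run(epochs=3, epoch_end_callbacks=[on_epoch_end])
```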
