Skip to content

Commit

Permalink
MINOR: Support explicit specification of device (#87)
Browse files Browse the repository at this point in the history
* MINOR: Support explicit specification of device

* lint
  • Loading branch information
ncliang committed Aug 16, 2023
1 parent 13f4f41 commit c42cfde
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
22 changes: 14 additions & 8 deletions mteb-zh/mteb_zh/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,28 @@ class ModelType(str, Enum):
azure = 'azure'


def load_model(model_type: ModelType, model_id: str | None = None) -> MTEBModel:
class DeviceType(str, Enum):
cpu = 'cpu'
cuda = 'cuda'
mps = 'mps'


def load_model(model_type: ModelType, model_id: str | None = None, device: DeviceType | None = None) -> MTEBModel:
match model_type:
case ModelType.sentence_transformer:
if model_id is None:
raise ValueError('model_name must be specified for sentence_transformer')
return SentenceTransformer(model_id)
return SentenceTransformer(model_id, device=device)
case ModelType.text2vec:
try:
from text2vec import SentenceModel # type: ignore
except ImportError:
raise ImportError('text2vec is not installed, please install it with "pip install text2vec"')

if model_id is None:
return SentenceModel()
return SentenceModel(device=device)
else:
return SentenceModel(model_id)
return SentenceModel(model_id, device=device)
case ModelType.openai:
if model_id is None:
return OpenAIModel(model_name='text-embedding-ada-002')
Expand All @@ -58,14 +64,14 @@ def load_model(model_type: ModelType, model_id: str | None = None) -> MTEBModel:
return AzureModel(model_name=model_id)
case ModelType.luotuo:
if model_id is None:
return LuotuoBertModel(model_name='silk-road/luotuo-bert')
return LuotuoBertModel(model_name='silk-road/luotuo-bert', device=device)
else:
return LuotuoBertModel(model_name=model_id)
return LuotuoBertModel(model_name=model_id, device=device)
case ModelType.erlangshen:
if model_id is None:
return ErLangShenModel(model_name='IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
return ErLangShenModel(model_name='IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese', device=device)
else:
return ErLangShenModel(model_name=model_id)
return ErLangShenModel(model_name=model_id, device=device)
case ModelType.minimax:
if model_id is None:
return MiniMaxModel()
Expand Down
5 changes: 3 additions & 2 deletions mteb-zh/run_mteb_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import typer
from mteb import MTEB, AbsTask
from mteb_zh.models import ModelType, load_model
from mteb_zh.models import DeviceType, ModelType, load_model
from mteb_zh.tasks import (
GubaEastmony,
IFlyTek,
Expand Down Expand Up @@ -47,9 +47,10 @@ def main(
task_type: TaskType = TaskType.Classification,
task_name: str | None = None,
output_folder: Path = Path('results'),
device: DeviceType | None = None,
):
output_folder = Path(output_folder)
model = load_model(model_type, model_id)
model = load_model(model_type, model_id, device)

if task_name:
tasks = filter_by_name(task_name)
Expand Down

0 comments on commit c42cfde

Please sign in to comment.