-
Notifications
You must be signed in to change notification settings - Fork 32
/
hf_utils.py
32 lines (27 loc) · 1 KB
/
hf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from .typing import assert_type
from transformers import AutoConfig, PreTrainedModel
from typing import Type
import transformers
def get_model_class(model_str: str) -> Type[PreTrainedModel]:
    """Get the appropriate model class for a model string.

    Loads the config for ``model_str`` and inspects the architecture class
    names listed in ``config.architectures``. Suffixes are tried in preference
    order. The first architecture name that ends with the current suffix and is
    exported by the installed ``transformers`` package is returned.

    Args:
        model_str: Model identifier or local path, passed unchanged to
            ``AutoConfig.from_pretrained``.

    Returns:
        The class object looked up by name on the top-level ``transformers``
        module.

    Raises:
        ValueError: If no listed architecture both ends with a supported
            suffix and resolves to an attribute of ``transformers``.
    """
    model_cfg = AutoConfig.from_pretrained(model_str)
    # ``architectures`` is optional on transformers configs and may be None;
    # treat that as "no architectures" so we reach the ValueError below instead
    # of failing inside the type assertion.
    archs = assert_type(list, model_cfg.architectures or [])

    # Ordered by preference
    suffixes = [
        # Fine-tuned for classification
        "SequenceClassification",
        # Encoder-decoder models
        "ConditionalGeneration",
        # Autoregressive models
        "CausalLM",
        "LMHeadModel",
    ]
    for suffix in suffixes:
        # Check if any of the architectures in the config end with the suffix.
        # If so, return the corresponding model class.
        for arch_str in archs:
            if not arch_str.endswith(suffix):
                continue
            # A config may name a class that the installed transformers does
            # not export (e.g. remote-code or newer architectures). Skip it
            # rather than raising AttributeError, so remaining candidates are
            # still considered.
            model_cls = getattr(transformers, arch_str, None)
            if model_cls is not None:
                return model_cls

    raise ValueError(
        f"'{model_str}' does not have any supported architectures: {archs}"
    )