Skip to content

Commit

Permalink
reorganize imports in hf utils (#7414)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Mar 25, 2024
1 parent bfd0d3e commit 41aa6f2
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
import torch

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
from huggingface_hub import HfApi, InferenceClient, model_info
from huggingface_hub.utils import RepositoryNotFoundError

Expand Down Expand Up @@ -120,7 +120,7 @@ def resolve_hf_pipeline_kwargs(
:param token: The token to use as HTTP bearer authorization for remote files.
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
"""
transformers_import.check()
huggingface_hub_import.check()

token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters
Expand Down Expand Up @@ -173,7 +173,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se
:param token: The optional authentication token.
:raises ValueError: If the model is not found or is not a embedding model.
"""
transformers_import.check()
huggingface_hub_import.check()

api = HfApi()
try:
Expand Down Expand Up @@ -202,7 +202,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
:param additional_accepted_params: An optional list of strings representing additional accepted parameters.
:raises ValueError: If any unknown text generation parameters are provided.
"""
transformers_import.check()
huggingface_hub_import.check()

if kwargs:
accepted_params = {
Expand All @@ -219,10 +219,11 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
)


with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import:
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer

torch_and_transformers_import.check()
torch_import.check()
transformers_import.check()

class StopWordsCriteria(StoppingCriteria):
"""
Expand Down

0 comments on commit 41aa6f2

Please sign in to comment.