From 41aa6f2b58266933dbf32ce2b3552f3bb5351b61 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 25 Mar 2024 11:41:16 +0100 Subject: [PATCH] reorganize imports in hf utils (#7414) --- haystack/utils/hf.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index b3afe20c17..4b017093e7 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -14,7 +14,7 @@ with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import: import torch -with LazyImport(message="Run 'pip install transformers'") as transformers_import: +with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import: from huggingface_hub import HfApi, InferenceClient, model_info from huggingface_hub.utils import RepositoryNotFoundError @@ -120,7 +120,7 @@ def resolve_hf_pipeline_kwargs( :param token: The token to use as HTTP bearer authorization for remote files. If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. """ - transformers_import.check() + huggingface_hub_import.check() token = token.resolve_value() if token else None # check if the huggingface_pipeline_kwargs contain the essential parameters @@ -173,7 +173,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se :param token: The optional authentication token. :raises ValueError: If the model is not found or is not a embedding model. """ - transformers_import.check() + huggingface_hub_import.check() api = HfApi() try: @@ -202,7 +202,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte :param additional_accepted_params: An optional list of strings representing additional accepted parameters. :raises ValueError: If any unknown text generation parameters are provided. """ - transformers_import.check() + huggingface_hub_import.check() if kwargs: accepted_params = { @@ -219,10 +219,11 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte ) -with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import: +with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer - torch_and_transformers_import.check() + torch_import.check() + transformers_import.check() class StopWordsCriteria(StoppingCriteria): """