From 0f54593e41b2464bc7c9d7609668462e5099d399 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 27 Jul 2023 17:58:01 -0600 Subject: [PATCH] RemoteTgisTextGen: Raise if trying to run remote-only model with local TGIS Signed-off-by: Gabe Goodhart --- .../text_generation/text_generation.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation.py b/caikit_nlp/modules/text_generation/text_generation.py index 52be2802..4ea0bb12 100644 --- a/caikit_nlp/modules/text_generation/text_generation.py +++ b/caikit_nlp/modules/text_generation/text_generation.py @@ -14,7 +14,7 @@ # Standard -from typing import Iterable +from typing import Iterable, Optional import os # Third Party @@ -62,12 +62,12 @@ class TextGeneration(ModuleBase): def __init__( self, base_model_name: str, - base_model: PretrainedModelBase = None, - bos_token: str = None, - sep_token: str = None, - eos_token: str = None, - pad_token: str = None, - tgis_backend: TGISBackend = None, + base_model: Optional[PretrainedModelBase] = None, + bos_token: Optional[str] = None, + sep_token: Optional[str] = None, + eos_token: Optional[str] = None, + pad_token: Optional[str] = None, + tgis_backend: Optional[TGISBackend] = None, ): super().__init__() @@ -89,6 +89,14 @@ def __init__( self._client = tgis_backend.get_client(base_model_name) # mark that the model is loaded so that we can unload it later self._model_loaded = True + # Make sure that we either have a base model or TGIS is running as a + # remote-proxy + error.value_check( + "", + self.base_model or not tgis_backend.local_tgis, + "Cannot run model {} with TGIS locally since it has no base artifacts", + base_model_name, + ) self._bos_token = bos_token self._sep_token = sep_token @@ -122,7 +130,6 @@ def bootstrap(cls, base_model_path: str, load_backend: BackendBase = None): """ # pylint: disable=duplicate-code model_config = AutoConfig.from_pretrained(base_model_path) - resource_type = None for resource in cls.supported_resources: if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: