Skip to content

Commit

Permalink
TgisAutoFinder: Add the ability to give a specific TGIS backend to us…
Browse files Browse the repository at this point in the history
…e from the finder

This requires that the module know that it may receive an in-memory
ModuleConfig which may have a tgis_backend attribute set

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Aug 8, 2023
1 parent 666066d commit 4b587c6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
9 changes: 7 additions & 2 deletions caikit_nlp/model_management/tgis_auto_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,20 @@ def find_model(

# If connection is ok, set up the module config to point to the remote
# TGIS text generation module
return ModuleConfig(
cfg = ModuleConfig(
{
"module_id": TextGenerationTGIS.MODULE_ID,
"module_class": TextGenerationTGIS.MODULE_CLASS,
"name": TextGenerationTGIS.MODULE_NAME,
"version": TextGenerationTGIS.MODULE_VERSION,
"base_model_name": model_path,
"model_name": model_path,
}
)
# Set a special indicator in the module config to use the backend that
# this finder found. This will override the backend found by the local
# initializer.
cfg.tgis_backend = self._tgis_backend
return cfg


model_finder_factory.register(TGISAutoFinder)
3 changes: 2 additions & 1 deletion caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration":
error.type_check("<NLP03521359E>", TGISBackend, load_backend=load_backend)

config = ModuleConfig.load(model_path)
tgis_backend = config.tgis_backend or load_backend
artifacts_path = config.artifact_path
if artifacts_path:
model_name = os.path.join(model_path, artifacts_path)
Expand All @@ -163,7 +164,7 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration":
sep_token=config.sep_token,
eos_token=config.eos_token,
pad_token=config.pad_token,
tgis_backend=load_backend,
tgis_backend=tgis_backend,
)

def save(self, model_path: str):
Expand Down

0 comments on commit 4b587c6

Please sign in to comment.