Skip to content

Commit

Permalink
Merge branch 'main' into fix_train_service_gen
Browse files Browse the repository at this point in the history
  • Loading branch information
gkumbhat committed Aug 10, 2023
2 parents c8868b0 + 2650797 commit a7fd9ae
Show file tree
Hide file tree
Showing 12 changed files with 470 additions and 4 deletions.
2 changes: 1 addition & 1 deletion caikit_nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# Local
# Import subpackages
from . import config, data_model
from . import config, data_model, model_management
from .config import *
from .data_model import *
from .modules import *
Expand Down
16 changes: 16 additions & 0 deletions caikit_nlp/model_management/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .tgis_auto_finder import TGISAutoFinder
142 changes: 142 additions & 0 deletions caikit_nlp/model_management/tgis_auto_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The TGISAutoFinder implements the ModelFinder interface to provide automatic
discovery of text-generation models that can be auto-configured to run against
a remote TGIS model.
"""
# Standard
from typing import Optional

# First Party
from caikit.core import MODEL_MANAGER, error_handler
from caikit.core.model_management import ModelFinderBase, model_finder_factory
from caikit.core.modules import ModuleConfig
from caikit_tgis_backend import TGISBackend
import aconfig
import alog

# Local
from ..modules.text_generation import TextGenerationTGIS

# Module-level logging channel ("TGIS_FND") and the error handler obtained
# from it; the handler is used below for type/value validation.
log = alog.use_channel("TGIS_FND")
error = error_handler.get(log)


class TGISAutoFinder(ModelFinderBase):
    # Reuse the module docstring as the class docstring
    __doc__ = __doc__

    # Type name of this finder (the value used for `type:` in finder configs)
    name = "TGIS-AUTO"

    # Keys read from the finder's config blob
    _LOCAL_INITIALIZER_NAME_KEY = "local_initializer_name"
    _TGIS_BACKEND_PRIORITY_KEY = "tgis_backend_priority"

    def __init__(self, config: aconfig.Config, instance_name: str = ""):
        """Construct the finder from its model finder factory config and
        resolve the TGIS backend that it will use for all lookups.

        Config schema:
            local_initializer_name:
                type: string
                default: "default"
                description: The name within the initializers config for the
                    LOCAL initializer that will hold the tgis backend to use
            tgis_backend_priority:
                type: integer
                description: Index within the backend_priority list for the
                    TGIS backend to use. If not set, the first TGIS backend
                    found will be used.

        Args:
            config (aconfig.Config): The configuration blob from caikit's
                model_management factory construction
            instance_name (str): The name of this finder instance
        """
        initializer_name = config.get(self._LOCAL_INITIALIZER_NAME_KEY, "default")
        priority = config.get(self._TGIS_BACKEND_PRIORITY_KEY)

        # Validate the config values before touching the model manager
        error.type_check(
            "<NLP97312902E>",
            str,
            local_initializer_name=initializer_name,
        )
        error.type_check(
            "<NLP97312903E>",
            int,
            tgis_backend_priority=priority,
            allow_none=True,
        )

        # The candidate backends are those held by the named initializer
        backends = MODEL_MANAGER.get_initializer(initializer_name).backends

        if priority is None:
            # No explicit index configured: use the first backend whose type
            # matches the TGIS backend type
            first_tgis = next(
                (
                    candidate
                    for candidate in backends
                    if candidate.backend_type == TGISBackend.backend_type
                ),
                None,
            )
            error.value_check(
                "<NLP96294266E>",
                first_tgis is not None,
                "No TGIS backend found!",
            )
            self._tgis_backend = first_tgis
            return

        # Explicit index configured: it must be in range and must point at a
        # TGIS backend
        error.value_check(
            "<NLP87928813E>",
            0 <= priority < len(backends),
            "Invalid {}: {}",
            self._TGIS_BACKEND_PRIORITY_KEY,
            priority,
        )
        self._tgis_backend = backends[priority]
        error.value_check(
            "<NLP77150201E>",
            self._tgis_backend.backend_type == TGISBackend.backend_type,
            "Index {} is not a TGIS backend",
            priority,
        )

    def find_model(
        self,
        model_path: str,
        **kwargs,
    ) -> Optional[ModuleConfig]:
        """Find the model if the TGIS backend can provide a connection for it.

        Args:
            model_path (str): Identifier of the model; passed to the backend
                as the model_id and stored as the config's model_name
            **kwargs: Accepted but not used by this finder

        Returns:
            Optional[ModuleConfig]: A config describing a TextGenerationTGIS
                module for the model, or None when the backend returns no
                connection for it
        """
        log.debug2("Attempting to setup TGIS client for %s", model_path)
        connection = self._tgis_backend.get_connection(model_id=model_path)
        if connection is None:
            log.debug2("TGIS cannot connect to model %s", model_path)
            return None

        # A connection exists, so describe the remote TGIS text generation
        # module that should serve this model
        module_fields = {
            "module_id": TextGenerationTGIS.MODULE_ID,
            "module_class": TextGenerationTGIS.MODULE_CLASS,
            "name": TextGenerationTGIS.MODULE_NAME,
            "version": TextGenerationTGIS.MODULE_VERSION,
            "model_name": model_path,
        }
        module_cfg = ModuleConfig(module_fields)

        # Attach the backend resolved by this finder as a special indicator on
        # the config so that it is used in place of the backend found by the
        # local initializer.
        module_cfg.tgis_backend = self._tgis_backend
        return module_cfg


# Register the finder class with caikit's model finder factory
model_finder_factory.register(TGISAutoFinder)
3 changes: 2 additions & 1 deletion caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration":
error.type_check("<NLP03521359E>", TGISBackend, load_backend=load_backend)

config = ModuleConfig.load(model_path)
tgis_backend = config.tgis_backend or load_backend
artifacts_path = config.artifact_path
if artifacts_path:
model_name = os.path.join(model_path, artifacts_path)
Expand All @@ -163,7 +164,7 @@ def load(cls, model_path: str, load_backend: BackendBase) -> "TextGeneration":
sep_token=config.sep_token,
eos_token=config.eos_token,
pad_token=config.pad_token,
tgis_backend=load_backend,
tgis_backend=tgis_backend,
)

def save(self, model_path: str):
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.13.0,<0.15.0",
"caikit-tgis-backend>=0.1.14,<0.2.0",
"caikit[runtime-grpc,runtime-http]>=0.15.0,<0.16.0",
"caikit-tgis-backend>=0.1.15,<0.2.0",

# TODO: loosen dependencies
"accelerate>=0.18.0",
Expand Down
7 changes: 7 additions & 0 deletions runtime_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ runtime:
size: 0 # Set to batch size for batching

model_management:
finders:
default:
type: LOCAL
remote_tgis:
type: TGIS-AUTO
config:
test_connection: true
initializers:
default:
type: LOCAL
Expand Down
29 changes: 29 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This sets up global test configs when pytest starts
"""

# Standard
import os

# First Party
import alog

# Configure logging from the environment. Each setting falls back to a default
# when its variable is unset: LOG_LEVEL -> "off", LOG_FILTERS -> "urllib3:off".
# The thread_id flag is True only when LOG_THREAD_ID is exactly "true".
alog.configure(
    default_level=os.environ.get("LOG_LEVEL", "off"),
    filters=os.environ.get("LOG_FILTERS", "urllib3:off"),
    thread_id=os.environ.get("LOG_THREAD_ID", "") == "true",
)
1 change: 1 addition & 0 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
def get_client(self, base_model_name):
    """Record a TGISConnection for the given model and return a stub client.

    The connection is stored in self._model_connections under
    base_model_name, using a fixed hostname ("foo.bar"), the model name as
    the model_id, and self._temp_dir as the prompt directory.

    Args:
        base_model_name: Name of the base model to create the connection and
            stub client for

    Returns:
        StubTGISClient: A stub client constructed with base_model_name
    """
    self._model_connections[base_model_name] = TGISConnection(
        hostname="foo.bar",
        model_id=base_model_name,
        prompt_dir=self._temp_dir,
    )
    return StubTGISClient(base_model_name)
Expand Down
Empty file.
Loading

0 comments on commit a7fd9ae

Please sign in to comment.