# load_and_run_distributed_peft.py
"""This script loads and runs a sample PEFT model as a caikit module using
the TGIS backend.
In a nutshell, this does the following:
- Check if `text-generation-launcher` is defined; if it doesn't, assume we have a Docker image
that is (currently hardcoded to be) able to run TGIS & expose the proper ports, and patch a
wrapper script around it onto our path so that the TGIS backend falls back to leveraging it
- Load the model through caikit
- Run an inference text generation request and dump the (garbage) output back to the console
"""
# Standard
from shutil import which
import os

# First Party
from caikit.core.module_backend_config import _CONFIGURED_BACKENDS, configure
from caikit_tgis_backend import TGISBackend
import alog
import caikit

# Local
import caikit_nlp
alog.configure("debug4")

PREFIX_PATH = "prompt_prefixes"

# If text-generation-launcher is not available, fall back to the wrapper script
# that lives alongside this file by appending this directory to the PATH
has_text_gen = which("text-generation-launcher")
if not has_text_gen:
    print("Text generation server command not found; using Docker override")
    this_dir = os.path.dirname(os.path.abspath(__file__))
    os.environ["PATH"] += ":" + this_dir
    assert (
        which("text-generation-launcher") is not None
    ), "Text generation script not found!"
# Configure caikit to prioritize the TGIS backend
_CONFIGURED_BACKENDS.clear()
# TGIS backend settings noted here for reference:
# load_timeout: 320
# grpc_port: null
# http_port: 3001
# health_poll_delay: 1.0
caikit.configure(
    config_dict={"module_backends": {"priority": [TGISBackend.backend_type]}}
)  # should not be necessary, but just in case
configure()  # backend configure
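# The same priority could instead come from a caikit config YAML rather than
# config_dict. A sketch, assuming TGISBackend.backend_type resolves to "TGIS":
#
#   module_backends:
#     priority:
#       - TGIS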
# Load the prompt-tuned model with the TGIS backend; the sample prompt under
# prompt_prefixes/ is assumed to have been created beforehand (e.g., by the
# prompt tuning example in this repo)
prefix_model_path = os.path.join(PREFIX_PATH, "sample_prompt")
my_model = caikit.load(prefix_model_path)

# Run a sample text generation inference request
sample_text = "@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand > different sizing"
sample_output = my_model.run(sample_text)

print("---------- Model result ----------")
print(sample_output)
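# An assumed invocation (requires Docker or a local TGIS install, plus a
# trained sample prompt in prompt_prefixes/):
#
#   python load_and_run_distributed_peft.py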