-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
clarifai.py
95 lines (81 loc) · 2.94 KB
/
clarifai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Clarifai LM integration"""
from typing import Any, Optional
from dsp.modules.lm import LM
class ClarifaiLLM(LM):
    """Integration to call models hosted on the Clarifai platform.

    Args:
        model (str, optional): Clarifai URL of the model. Defaults to the
            URL of the hosted Mistral-7B-Instruct model.
        api_key (Optional[str], optional): CLARIFAI_PAT token. Defaults to None.
        **kwargs: Additional arguments to pass to the API provider. An
            ``inference_params`` dict, when given, is forwarded to the model
            on every request; ``n`` sets the default number of completions
            produced per call.

    Example:
        import dspy
        dspy.configure(lm=dspy.Clarifai(model=MODEL_URL,
                                        api_key=CLARIFAI_PAT,
                                        inference_params={"max_tokens":100,'temperature':0.6}))
    """

    def __init__(
        self,
        model: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """Create the Clarifai model client and record the default settings.

        Raises:
            ImportError: If the ``clarifai`` package is not installed.
        """
        super().__init__(model)

        # Imported lazily so the module can be loaded without the optional
        # `clarifai` dependency installed.
        try:
            from clarifai.client.model import Model
        except ImportError as err:
            raise ImportError("ClarifaiLLM requires `pip install clarifai`.") from err

        self.provider = "clarifai"
        self.pat = api_key
        self._model = Model(url=model, pat=api_key)
        self.kwargs = {"n": 1, **kwargs}
        self.history: list[dict[str, Any]] = []

        # Mirror the two common inference settings at the top level of
        # `self.kwargs`, falling back to defaults when they are not supplied
        # (or when `inference_params` is missing/None).
        inference_params = self.kwargs.get("inference_params") or {}
        self.kwargs["temperature"] = inference_params.get("temperature", 0.0)
        self.kwargs["max_tokens"] = inference_params.get("max_tokens", 150)

    def basic_request(self, prompt, **kwargs):
        """Send ``prompt`` to the model and return the raw text of the first output.

        Args:
            prompt: Text prompt; it is UTF-8 encoded before being sent.
            **kwargs: Per-call settings. A per-call ``inference_params`` dict
                is merged over the instance-level one for this request only.

        Returns:
            The ``data.text.raw`` field of the first output of the prediction.
        """
        # Build a fresh dict so the instance-level defaults are never handed
        # to the SDK directly, and so per-call overrides take precedence.
        params = {
            **(self.kwargs.get("inference_params") or {}),
            **(kwargs.get("inference_params") or {}),
        }

        prediction = self._model.predict_by_bytes(
            input_bytes=prompt.encode(encoding="utf-8"),
            input_type="text",
            inference_params=params,
        )
        response = prediction.outputs[0].data.text.raw

        merged_kwargs = {**self.kwargs, **kwargs}
        if "inference_params" in merged_kwargs:
            # Record the parameters that were actually sent with the request.
            merged_kwargs["inference_params"] = params

        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": merged_kwargs,
            },
        )
        return response

    def request(self, prompt: str, **kwargs):
        """Delegate to :meth:`basic_request` with the same arguments."""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Return a list of completions for ``prompt``.

        Args:
            prompt: Text prompt to complete.
            only_completed: Must be True; other values are not supported yet.
            return_sorted: Must be False; sorting is not supported yet.
            **kwargs: Per-call settings. ``n`` (number of completions) is
                consumed here and defaults to the ``n`` configured on the
                instance; the rest is passed to :meth:`request`.

        Returns:
            A list with one response per requested completion.
        """
        # Kept as assertions so callers that rely on AssertionError for these
        # unsupported options keep working.
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        n = kwargs.pop("n", self.kwargs.get("n", 1))

        # One request is issued per completion.
        return [self.request(prompt, **kwargs) for _ in range(n)]