Skip to content

Commit

Permalink
fix(dspy): return type of basic request as list not str
Browse files Browse the repository at this point in the history
  • Loading branch information
Anindyadeep committed Jun 19, 2024
1 parent a2112a0 commit 5d98bf2
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions dsp/modules/premai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,12 @@ def get_premai_api_key(api_key: Optional[str] = None) -> str:


class PremAI(LM):
"""Wrapper around Prem AI's API."""

def __init__(
self,
project_id: int,
model: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs: dict,
) -> None:
"""Parameters
Expand All @@ -63,6 +62,7 @@ def __init__(
api_key: Optional[str]
Prem AI API key, to connect with the API. If not provided then it will check from env var by the name
PREMAI_API_KEY
kwargs: Optional[dict] For any additional paramters
"""
self.model = "default" if model is None else model
super().__init__(self.model)
Expand All @@ -77,6 +77,7 @@ def __init__(
self.client = premai.Prem(api_key=api_key)
self.provider = "premai"
self.history: list[dict[str, Any]] = []
self.kwargs = kwargs if kwargs else {}

@property
def _default_params(self) -> dict[str, Any]:
Expand All @@ -103,6 +104,8 @@ def _get_all_kwargs(self, **kwargs) -> dict[str, Any]:
"seed",
]
keys_to_remove = []
kwargs = {**kwargs, **self.kwargs}

for key in kwargs:
if key in kwargs_to_ignore:
warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.", stacklevel=2)
Expand All @@ -117,10 +120,9 @@ def _get_all_kwargs(self, **kwargs) -> dict[str, Any]:
all_kwargs.pop(key, None)
return all_kwargs

def basic_request(self, prompt, **kwargs) -> str:
def basic_request(self, prompt, **kwargs) -> list[str]:
"""Handles retrieval of completions from Prem AI whilst handling API errors."""
all_kwargs = self._get_all_kwargs(**kwargs)

if "template_id" not in all_kwargs:
messages = [{"role": "user", "content": prompt}]
else:
Expand Down Expand Up @@ -155,9 +157,6 @@ def basic_request(self, prompt, **kwargs) -> str:
messages=messages,
**all_kwargs,
)
if not response.choices:
raise premai_api_error("ChatResponse must have at least one candidate")

content = response.choices[0].message.content
if not content:
raise premai_api_error("ChatResponse is none")
Expand All @@ -175,8 +174,7 @@ def basic_request(self, prompt, **kwargs) -> str:
"kwargs": kwargs,
},
)

return output_text
return [output_text]

@backoff.on_exception(
backoff.expo,
Expand Down

0 comments on commit 5d98bf2

Please sign in to comment.