
chore!: Rename model_name to model in the Cohere integration #222

Merged
9 commits merged on Jan 17, 2024
rename model_name to model in cohere chat generator
ZanSara committed Jan 16, 2024
commit b30ce6b8bcff5fe74e624b03b92e8e2f9ce10ae8
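Because the PR title carries the breaking-change marker (`chore!`), callers that pass the model by keyword have to switch from `model_name` to `model`. A minimal before/after sketch, assuming the import path implied by the file below (`cohere_haystack.chat.chat_generator`) and a placeholder API key:

from cohere_haystack.chat.chat_generator import CohereChatGenerator

# Before this commit:
# generator = CohereChatGenerator(api_key="my-cohere-key", model_name="command")

# After this commit the keyword argument is `model`; passing `model_name`
# now raises a TypeError for an unexpected keyword argument.
generator = CohereChatGenerator(api_key="my-cohere-key", model="command")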
integrations/cohere/src/cohere_haystack/chat/chat_generator.py (16 changes: 8 additions & 8 deletions)
@@ -27,7 +27,7 @@ class CohereChatGenerator:
     def __init__(
         self,
         api_key: Optional[str] = None,
-        model_name: str = "command",
+        model: str = "command",
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         api_base_url: Optional[str] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
@@ -37,7 +37,7 @@ def __init__(
         Initialize the CohereChatGenerator instance.

         :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
-        :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly,
+        :param model: The name of the model to use. Available models are: [command, command-light, command-nightly,
             command-nightly-light]. Defaults to "command".
         :param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
         :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
@@ -82,7 +82,7 @@ def __init__(
         if generation_kwargs is None:
             generation_kwargs = {}
         self.api_key = api_key
-        self.model_name = model_name
+        self.model = model
         self.streaming_callback = streaming_callback
         self.api_base_url = api_base_url
         self.generation_kwargs = generation_kwargs
@@ -93,7 +93,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
         """
         Data that is sent to Posthog for usage analytics.
         """
-        return {"model": self.model_name}
+        return {"model": self.model}

     def to_dict(self) -> Dict[str, Any]:
         """
@@ -103,7 +103,7 @@ def to_dict(self) -> Dict[str, Any]:
         callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
         return default_to_dict(
             self,
-            model_name=self.model_name,
+            model=self.model,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
             generation_kwargs=self.generation_kwargs,
@@ -147,7 +147,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         chat_history = [self._message_to_dict(m) for m in messages[:-1]]
         response = self.client.chat(
             message=messages[-1].content,
-            model=self.model_name,
+            model=self.model,
             stream=self.streaming_callback is not None,
             chat_history=chat_history,
             **generation_kwargs,
@@ -160,7 +160,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             chat_message = ChatMessage.from_assistant(content=response.texts)
             chat_message.meta.update(
                 {
-                    "model": self.model_name,
+                    "model": self.model,
                     "usage": response.token_count,
                     "index": 0,
                     "finish_reason": response.finish_reason,
@@ -193,7 +193,7 @@ def _build_message(self, cohere_response):
         message = ChatMessage.from_assistant(content=content)
         message.meta.update(
             {
-                "model": self.model_name,
+                "model": self.model,
                 "usage": cohere_response.token_count,
                 "index": 0,
                 "finish_reason": None,
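The rename also flows through `to_dict()`: the init parameter that used to be written out as `model_name` is now written as `model`, so component configurations saved before this commit need the key updated before they can be deserialized. A minimal sketch, assuming `default_to_dict` follows the usual Haystack 2.x envelope with constructor arguments nested under "init_parameters" (that envelope and the placeholder API key are assumptions, not taken from this diff):

from cohere_haystack.chat.chat_generator import CohereChatGenerator

generator = CohereChatGenerator(api_key="my-cohere-key", model="command")
data = generator.to_dict()

# Previously the serialized key was "model_name"; after this commit it is "model".
assert data["init_parameters"]["model"] == "command"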