From b30ce6b8bcff5fe74e624b03b92e8e2f9ce10ae8 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 17:15:40 +0100 Subject: [PATCH 1/9] rename model_name to model in cohere chat generator --- .../src/cohere_haystack/chat/chat_generator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py index 1e5aa0e42..0ff29ce14 100644 --- a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py +++ b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py @@ -27,7 +27,7 @@ class CohereChatGenerator: def __init__( self, api_key: Optional[str] = None, - model_name: str = "command", + model: str = "command", streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, @@ -37,7 +37,7 @@ def __init__( Initialize the CohereChatGenerator instance. :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. - :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, + :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command". :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". @@ -82,7 +82,7 @@ def __init__( if generation_kwargs is None: generation_kwargs = {} self.api_key = api_key - self.model_name = model_name + self.model = model self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.generation_kwargs = generation_kwargs @@ -93,7 +93,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]: """ Data that is sent to Posthog for usage analytics. """ - return {"model": self.model_name} + return {"model": self.model} def to_dict(self) -> Dict[str, Any]: """ @@ -103,7 +103,7 @@ def to_dict(self) -> Dict[str, Any]: callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, - model_name=self.model_name, + model=self.model, streaming_callback=callback_name, api_base_url=self.api_base_url, generation_kwargs=self.generation_kwargs, @@ -147,7 +147,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, chat_history = [self._message_to_dict(m) for m in messages[:-1]] response = self.client.chat( message=messages[-1].content, - model=self.model_name, + model=self.model, stream=self.streaming_callback is not None, chat_history=chat_history, **generation_kwargs, @@ -160,7 +160,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, chat_message = ChatMessage.from_assistant(content=response.texts) chat_message.meta.update( { - "model": self.model_name, + "model": self.model, "usage": response.token_count, "index": 0, "finish_reason": response.finish_reason, @@ -193,7 +193,7 @@ def _build_message(self, cohere_response): message = ChatMessage.from_assistant(content=content) message.meta.update( { - "model": self.model_name, + "model": self.model, "usage": cohere_response.token_count, "index": 0, "finish_reason": None, From 9ead87791f96a471b1acad7b69783a1fbf0b0040 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 17:16:04 +0100 Subject: [PATCH 2/9] fix chat generator tests --- .../tests/test_cohere_chat_generator.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 20a02863a..bc6286d8e 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -57,7 +57,7 @@ class TestCohereChatGenerator: def test_init_default(self): component = CohereChatGenerator(api_key="test-api-key") assert component.api_key == "test-api-key" - assert component.model_name == "command" + assert component.model == "command" assert component.streaming_callback is None assert component.api_base_url == cohere.COHERE_API_URL assert not component.generation_kwargs @@ -72,13 +72,13 @@ def test_init_fail_wo_api_key(self, monkeypatch): def test_init_with_parameters(self): component = CohereChatGenerator( api_key="test-api-key", - model_name="command-nightly", + model="command-nightly", streaming_callback=default_streaming_callback, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) assert component.api_key == "test-api-key" - assert component.model_name == "command-nightly" + assert component.model == "command-nightly" assert component.streaming_callback is default_streaming_callback assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @@ -90,7 +90,7 @@ def test_to_dict_default(self): assert data == { "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "streaming_callback": None, "api_base_url": "https://api.cohere.ai", "generation_kwargs": {}, @@ -101,7 +101,7 @@ def test_to_dict_default(self): def test_to_dict_with_parameters(self): component = CohereChatGenerator( api_key="test-api-key", - model_name="command-nightly", + model="command-nightly", streaming_callback=default_streaming_callback, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, @@ -110,7 +110,7 @@ def test_to_dict_with_parameters(self): assert data == { "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model_name": "command-nightly", + "model": "command-nightly", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "api_base_url": "test-base-url", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -121,7 +121,7 @@ def test_to_dict_with_parameters(self): def test_to_dict_with_lambda_streaming_callback(self): component = CohereChatGenerator( api_key="test-api-key", - model_name="command", + model="command", streaming_callback=lambda x: x, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, @@ -130,7 +130,7 @@ def test_to_dict_with_lambda_streaming_callback(self): assert data == { "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "api_base_url": "test-base-url", "streaming_callback": "tests.test_cohere_chat_generator.", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -143,14 +143,14 @@ def test_from_dict(self, monkeypatch): data = { "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } component = CohereChatGenerator.from_dict(data) - assert component.model_name == "command" + assert component.model == "command" assert component.streaming_callback is default_streaming_callback assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @@ -161,7 +161,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): data = { "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -261,7 +261,7 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator( - model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY") + model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY") ) with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): component.run(chat_messages) From 008d1aaf71fdfe469a8b8c0d3d32bbcb498e802a Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:04:22 +0100 Subject: [PATCH 3/9] rename model_name to model in cohere generator --- integrations/cohere/src/cohere_haystack/generator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/generator.py b/integrations/cohere/src/cohere_haystack/generator.py index 66c80afa4..9917f17ea 100644 --- a/integrations/cohere/src/cohere_haystack/generator.py +++ b/integrations/cohere/src/cohere_haystack/generator.py @@ -32,7 +32,7 @@ class CohereGenerator: def __init__( self, api_key: Optional[str] = None, - model_name: str = "command", + model: str = "command", streaming_callback: Optional[Callable] = None, api_base_url: Optional[str] = None, **kwargs, @@ -41,7 +41,7 @@ def __init__( Instantiates a `CohereGenerator` component. :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. - :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, + :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command". :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". @@ -86,7 +86,7 @@ def __init__( api_base_url = COHERE_API_URL self.api_key = api_key - self.model_name = model_name + self.model = model self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.model_parameters = kwargs @@ -107,7 +107,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, - model_name=self.model_name, + model=self.model, streaming_callback=callback_name, api_base_url=self.api_base_url, **self.model_parameters, @@ -142,7 +142,7 @@ def run(self, prompt: str): :param prompt: The prompt to be sent to the generative model. """ response = self.client.generate( - model=self.model_name, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters + model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters ) if self.streaming_callback: metadata_dict: Dict[str, Any] = {} From 1231144c4459e62f92e64650b512562b812ef09d Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:04:52 +0100 Subject: [PATCH 4/9] fix generator tests --- .../cohere/tests/test_cohere_generators.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index ec8027d96..f22b38843 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -23,7 +23,7 @@ class TestCohereGenerator: def test_init_default(self): component = CohereGenerator(api_key="test-api-key") assert component.api_key == "test-api-key" - assert component.model_name == "command" + assert component.model == "command" assert component.streaming_callback is None assert component.api_base_url == COHERE_API_URL assert component.model_parameters == {} @@ -32,14 +32,14 @@ def test_init_with_parameters(self): callback = lambda x: x # noqa: E731 component = CohereGenerator( api_key="test-api-key", - model_name="command-light", + model="command-light", max_tokens=10, some_test_param="test-params", streaming_callback=callback, api_base_url="test-base-url", ) assert component.api_key == "test-api-key" - assert component.model_name == "command-light" + assert component.model == "command-light" assert component.streaming_callback == callback assert component.api_base_url == "test-base-url" assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} @@ -50,7 +50,7 @@ def test_to_dict_default(self): assert data == { "type": "cohere_haystack.generator.CohereGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "streaming_callback": None, "api_base_url": COHERE_API_URL, }, @@ -59,7 +59,7 @@ def test_to_dict_default(self): def test_to_dict_with_parameters(self): component = CohereGenerator( api_key="test-api-key", - model_name="command-light", + model="command-light", max_tokens=10, some_test_param="test-params", streaming_callback=default_streaming_callback, @@ -69,7 +69,7 @@ def test_to_dict_with_parameters(self): assert data == { "type": "cohere_haystack.generator.CohereGenerator", "init_parameters": { - "model_name": "command-light", + "model": "command-light", "max_tokens": 10, "some_test_param": "test-params", "api_base_url": "test-base-url", @@ -80,7 +80,7 @@ def test_to_dict_with_parameters(self): def test_to_dict_with_lambda_streaming_callback(self): component = CohereGenerator( api_key="test-api-key", - model_name="command", + model="command", max_tokens=10, some_test_param="test-params", streaming_callback=lambda x: x, @@ -90,7 +90,7 @@ def test_to_dict_with_lambda_streaming_callback(self): assert data == { "type": "cohere_haystack.generator.CohereGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "streaming_callback": "tests.test_cohere_generators.", "api_base_url": "test-base-url", "max_tokens": 10, @@ -103,7 +103,7 @@ def test_from_dict(self, monkeypatch): data = { "type": "cohere_haystack.generator.CohereGenerator", "init_parameters": { - "model_name": "command", + "model": "command", "max_tokens": 10, "some_test_param": "test-params", "api_base_url": "test-base-url", @@ -112,7 +112,7 @@ def test_from_dict(self, monkeypatch): } component: CohereGenerator = CohereGenerator.from_dict(data) assert component.api_key == "test-key" - assert component.model_name == "command" + assert component.model == "command" assert component.streaming_callback == default_streaming_callback assert component.api_base_url == "test-base-url" assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} @@ -144,10 +144,10 @@ def test_cohere_generator_run(self): reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration - def test_cohere_generator_run_wrong_model_name(self): + def test_cohere_generator_run_wrong_model(self): import cohere - component = CohereGenerator(model_name="something-obviously-wrong") + component = CohereGenerator(model="something-obviously-wrong") with pytest.raises( cohere.CohereAPIError, match="model not found, make sure the correct model ID was used and that you have access to the model.", From e9487826cb04fc3aa5a3a211b5f576024680fd25 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:06:09 +0100 Subject: [PATCH 5/9] rename model_name to model in cohere document embedder --- .../cohere_haystack/embedders/document_embedder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index bc0b9381d..151c4f794 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -36,7 +36,7 @@ class CohereDocumentEmbedder: def __init__( self, api_key: Optional[str] = None, - model_name: str = "embed-english-v2.0", + model: str = "embed-english-v2.0", input_type: str = "search_document", api_base_url: str = COHERE_API_URL, truncate: str = "END", @@ -53,7 +53,7 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, `"embed-multilingual-v2.0"`. This list of all supported models can be found in the @@ -88,7 +88,7 @@ def __init__( raise ValueError(msg) self.api_key = api_key - self.model_name = model_name + self.model = model self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate @@ -106,7 +106,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - model_name=self.model_name, + model=self.model, input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, @@ -160,7 +160,7 @@ def run(self, documents: List[Document]): self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) all_embeddings, metadata = asyncio.run( - get_async_response(cohere_client, texts_to_embed, self.model_name, self.input_type, self.truncate) + get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate) ) else: cohere_client = Client( @@ -169,7 +169,7 @@ def run(self, documents: List[Document]): all_embeddings, metadata = get_response( cohere_client, texts_to_embed, - self.model_name, + self.model, self.input_type, self.truncate, self.batch_size, From 346fa1d0f360e5daef932fa149541433a4bb6004 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:06:27 +0100 Subject: [PATCH 6/9] fix document embedder tests --- integrations/cohere/tests/test_document_embedder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index 02dbd4c3e..c9770737e 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -16,7 +16,7 @@ class TestCohereDocumentEmbedder: def test_init_default(self): embedder = CohereDocumentEmbedder(api_key="test-api-key") assert embedder.api_key == "test-api-key" - assert embedder.model_name == "embed-english-v2.0" + assert embedder.model == "embed-english-v2.0" assert embedder.input_type == "search_document" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" @@ -31,7 +31,7 @@ def test_init_default(self): def test_init_with_parameters(self): embedder = CohereDocumentEmbedder( api_key="test-api-key", - model_name="embed-multilingual-v2.0", + model="embed-multilingual-v2.0", input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", @@ -44,7 +44,7 @@ def test_init_with_parameters(self): embedding_separator="-", ) assert embedder.api_key == "test-api-key" - assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.model == "embed-multilingual-v2.0" assert embedder.input_type == "search_query" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" @@ -62,7 +62,7 @@ def test_to_dict(self): assert component_dict == { "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", "init_parameters": { - "model_name": "embed-english-v2.0", + "model": "embed-english-v2.0", "input_type": "search_document", "api_base_url": COHERE_API_URL, "truncate": "END", @@ -79,7 +79,7 @@ def test_to_dict(self): def test_to_dict_with_custom_init_parameters(self): embedder_component = CohereDocumentEmbedder( api_key="test-api-key", - model_name="embed-multilingual-v2.0", + model="embed-multilingual-v2.0", input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", @@ -95,7 +95,7 @@ def test_to_dict_with_custom_init_parameters(self): assert component_dict == { "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", "init_parameters": { - "model_name": "embed-multilingual-v2.0", + "model": "embed-multilingual-v2.0", "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", From a0ce7702b4cfe025a98951d517bb95d96dbde105 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:16:28 +0100 Subject: [PATCH 7/9] rename model_name to model in cohere text embedder --- .../src/cohere_haystack/embedders/text_embedder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 4ba8acd47..bfef97dc3 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -34,7 +34,7 @@ class CohereTextEmbedder: def __init__( self, api_key: Optional[str] = None, - model_name: str = "embed-english-v2.0", + model: str = "embed-english-v2.0", input_type: str = "search_query", api_base_url: str = COHERE_API_URL, truncate: str = "END", @@ -47,7 +47,7 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, `"embed-multilingual-v2.0"`. This list of all supported models can be found in the @@ -77,7 +77,7 @@ def __init__( raise ValueError(msg) self.api_key = api_key - self.model_name = model_name + self.model = model self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate @@ -91,7 +91,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - model_name=self.model_name, + model=self.model, input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, @@ -117,12 +117,12 @@ def run(self, text: str): self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) embedding, metadata = asyncio.run( - get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) + get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate) ) else: cohere_client = Client( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) - embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) + embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) return {"embedding": embedding[0], "meta": metadata} From 314a293819227012375bba2e6a4980595b675b16 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:16:45 +0100 Subject: [PATCH 8/9] fix text embedder tests --- integrations/cohere/tests/test_text_embedder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index 46f77cb43..7e91b4812 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -19,7 +19,7 @@ def test_init_default(self): embedder = CohereTextEmbedder(api_key="test-api-key") assert embedder.api_key == "test-api-key" - assert embedder.model_name == "embed-english-v2.0" + assert embedder.model == "embed-english-v2.0" assert embedder.input_type == "search_query" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" @@ -33,7 +33,7 @@ def test_init_with_parameters(self): """ embedder = CohereTextEmbedder( api_key="test-api-key", - model_name="embed-multilingual-v2.0", + model="embed-multilingual-v2.0", input_type="classification", api_base_url="https://custom-api-base-url.com", truncate="START", @@ -42,7 +42,7 @@ def test_init_with_parameters(self): timeout=60, ) assert embedder.api_key == "test-api-key" - assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.model == "embed-multilingual-v2.0" assert embedder.input_type == "classification" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" @@ -59,7 +59,7 @@ def test_to_dict(self): assert component_dict == { "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { - "model_name": "embed-english-v2.0", + "model": "embed-english-v2.0", "input_type": "search_query", "api_base_url": COHERE_API_URL, "truncate": "END", @@ -75,7 +75,7 @@ def test_to_dict_with_custom_init_parameters(self): """ embedder_component = CohereTextEmbedder( api_key="test-api-key", - model_name="embed-multilingual-v2.0", + model="embed-multilingual-v2.0", input_type="classification", api_base_url="https://custom-api-base-url.com", truncate="START", @@ -87,7 +87,7 @@ def test_to_dict_with_custom_init_parameters(self): assert component_dict == { "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { - "model_name": "embed-multilingual-v2.0", + "model": "embed-multilingual-v2.0", "input_type": "classification", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", From b0c5df16c8c048e58fb9bfe1ea45c97870806ef3 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 16 Jan 2024 18:20:03 +0100 Subject: [PATCH 9/9] black --- integrations/cohere/tests/test_cohere_chat_generator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index bc6286d8e..e93db51fd 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -260,9 +260,7 @@ def test_live_run(self): ) @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): - component = CohereChatGenerator( - model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY") - ) + component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): component.run(chat_messages)