diff --git a/docs/open_source/scan/scan_llm/index.md b/docs/open_source/scan/scan_llm/index.md
index e365601a7a..ba51f6c65b 100644
--- a/docs/open_source/scan/scan_llm/index.md
+++ b/docs/open_source/scan/scan_llm/index.md
@@ -130,6 +130,22 @@ giskard.llm.set_default_client(claude_client)
 set_default_embedding(embed_client)
 ```
 
+::::::
+::::::{tab-item} Gemini
+
+```python
+import os
+import giskard
+
+import google.generativeai as genai
+
+from giskard.llm.client.gemini import GeminiClient
+
+genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+giskard.llm.set_default_client(GeminiClient())
+```
+
 ::::::
 ::::::{tab-item} Custom Client
 
diff --git a/docs/open_source/setting_up/index.md b/docs/open_source/setting_up/index.md
index 23f08ab4b5..d179c2c21a 100644
--- a/docs/open_source/setting_up/index.md
+++ b/docs/open_source/setting_up/index.md
@@ -84,6 +84,21 @@ giskard.llm.set_default_client(claude_client)
 set_default_embedding(embed_client)
 ```
 
+## Gemini Client Setup
+
+```python
+import os
+import giskard
+
+import google.generativeai as genai
+
+from giskard.llm.client.gemini import GeminiClient
+
+genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+giskard.llm.set_default_client(GeminiClient())
+```
+
 ## Custom Client Setup
 
 ```python
diff --git a/docs/open_source/testset_generation/testset_generation/index.md b/docs/open_source/testset_generation/testset_generation/index.md
index 87fd50f65f..2db6ae2150 100644
--- a/docs/open_source/testset_generation/testset_generation/index.md
+++ b/docs/open_source/testset_generation/testset_generation/index.md
@@ -162,6 +162,22 @@ giskard.llm.set_default_client(claude_client)
 set_default_embedding(embed_client)
 ```
 
+::::::
+::::::{tab-item} Gemini
+
+```python
+import os
+import giskard
+
+import google.generativeai as genai
+
+from giskard.llm.client.gemini import GeminiClient
+
+genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+giskard.llm.set_default_client(GeminiClient())
+```
+
 ::::::
 ::::::{tab-item} Custom Client
 ```python
diff --git a/giskard/llm/client/gemini.py b/giskard/llm/client/gemini.py
new file mode 100644
index 0000000000..08aa6ee77b
--- /dev/null
+++ b/giskard/llm/client/gemini.py
@@ -0,0 +1,95 @@
+from typing import Optional, Sequence
+
+from logging import warning
+
+from ..config import LLMConfigurationError
+from ..errors import LLMImportError
+from . import LLMClient
+from .base import ChatMessage
+
+try:
+    import google.generativeai as genai
+    from google.generativeai.types import ContentDict
+except ImportError as err:
+    raise LLMImportError(
+        flavor="llm",
+        msg="To use Gemini models, please install the `google-generativeai` package with `pip install google-generativeai`",
+    ) from err
+
+AUTH_ERROR_MESSAGE = (
+    "Could not get a response from the Gemini API. Please make sure you have configured the API key by "
+    "calling `genai.configure(api_key=...)` or by setting GOOGLE_API_KEY in the environment."
+)
+
+
+def _format(messages: Sequence[ChatMessage]) -> Sequence[ContentDict]:
+    system_prompts = []
+    content = []
+
+    for message in messages:
+        if message.role == "system":
+            system_prompts.append(message.content)
+
+            if len(content) == 0:
+                content.append(ContentDict(role="model", parts=[]))
+
+            content[0]["parts"].insert(0, f"# System:\n{message.content}")
+
+            continue
+
+        role = "model" if message.role == "assistant" else "user"
+
+        # Consecutive messages need to be grouped
+        last_message = None if len(content) == 0 else content[-1]
+        if last_message is not None and last_message["role"] == role:
+            last_message["parts"].append(message.content)
+            continue
+
+        content.append(ContentDict(role=role, parts=[message.content]))
+
+    return content
+
+
+class GeminiClient(LLMClient):
+    def __init__(self, model: str = "gemini-pro", _client=None):
+        self.model = model
+        self._client = _client or genai.GenerativeModel(self.model)
+
+    def complete(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1.0,
+        max_tokens: Optional[int] = None,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> ChatMessage:
+        extra_params = dict()
+        if seed is not None:
+            extra_params["seed"] = seed
+
+        if format:
+            warning(f"Unsupported format '{format}', ignoring.")
+
+        try:
+            completion = self._client.generate_content(
+                contents=_format(messages),
+                generation_config=genai.types.GenerationConfig(
+                    temperature=temperature,
+                    max_output_tokens=max_tokens,
+                    **extra_params,
+                ),
+            )
+        except RuntimeError as err:
+            raise LLMConfigurationError(AUTH_ERROR_MESSAGE) from err
+
+        self.logger.log_call(
+            prompt_tokens=self._client.count_tokens([m.content for m in messages]),
+            sampled_tokens=self._client.count_tokens(completion.text),
+            model=self.model,
+            client_class=self.__class__.__name__,
+            caller_id=caller_id,
+        )
+
+        # Assuming the response structure is similar to the ChatMessage structure
+        return ChatMessage(role=completion.candidates[0].content.role, content=completion.text)
diff --git a/pyproject.toml b/pyproject.toml
index bffaeb523c..f78e3a1e7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,6 +100,7 @@ test = [
     "shap<0.45", # Fixing this to avoid changed on insights
     "ragas>=0.1.5",
     "nemoguardrails>=0.9.0",
+    "google-generativeai",
 ]
 doc = [
     "furo>=2023.5.20",
diff --git a/tests/llm/test_llm_client.py b/tests/llm/test_llm_client.py
index 641d69ac4c..40e899a35b 100644
--- a/tests/llm/test_llm_client.py
+++ b/tests/llm/test_llm_client.py
@@ -3,6 +3,7 @@
 
 import pydantic
 import pytest
+from google.generativeai.types import ContentDict
 from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionResponseChoice
 from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
 from mistralai.models.chat_completion import FinishReason, UsageInfo
@@ -12,6 +13,7 @@
 
 from giskard.llm.client import ChatMessage
 from giskard.llm.client.bedrock import ClaudeBedrockClient
+from giskard.llm.client.gemini import GeminiClient
 from giskard.llm.client.openai import OpenAIClient
 
 PYDANTIC_V2 = pydantic.__version__.startswith("2.")
@@ -125,3 +127,30 @@ def test_claude_bedrock_client():
     # Assert that the response is a ChatMessage and has the correct content
     assert isinstance(res, ChatMessage)
     assert res.content == "This is a test!"
+
+
+def test_gemini_client():
+    # Mock the Gemini client
+    gemini_api_client = Mock()
+    gemini_api_client.generate_content = MagicMock(
+        return_value=Mock(text="This is a test!", candidates=[Mock(content=Mock(role="assistant"))])
+    )
+    gemini_api_client.count_tokens = MagicMock(
+        side_effect=lambda text: sum(len(t.split()) for t in text) if isinstance(text, list) else len(text.split())
+    )
+
+    # Initialize the GeminiClient with the mocked gemini_api_client
+    client = GeminiClient(model="gemini-pro", _client=gemini_api_client)
+
+    # Call the complete method
+    res = client.complete([ChatMessage(role="user", content="Hello")], temperature=0.11, max_tokens=12)
+
+    # Assert that the generate_content method was called with the correct arguments
+    gemini_api_client.generate_content.assert_called_once()
+    assert gemini_api_client.generate_content.call_args[1]["contents"] == [ContentDict(role="user", parts=["Hello"])]
+    assert gemini_api_client.generate_content.call_args[1]["generation_config"].temperature == 0.11
+    assert gemini_api_client.generate_content.call_args[1]["generation_config"].max_output_tokens == 12
+
+    # Assert that the response is a ChatMessage and has the correct content
+    assert isinstance(res, ChatMessage)
+    assert res.content == "This is a test!"
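
Below is a quick, self-contained sketch (not part of the patch) of how the `_format` helper is expected to behave once the role fix above is in place. `ChatMessage` and `_format` come from the patch itself; the example message contents are made up for illustration, and the snippet assumes `giskard` and `google-generativeai` are installed:

```python
from giskard.llm.client import ChatMessage
from giskard.llm.client.gemini import _format

# Hypothetical conversation, used only to illustrate the grouping behaviour.
messages = [
    ChatMessage(role="system", content="You are a concise assistant."),
    ChatMessage(role="user", content="Hello!"),
    ChatMessage(role="user", content="What can you do?"),
    ChatMessage(role="assistant", content="I can answer questions."),
]

# The system prompt is folded into a leading "model" content block, and
# consecutive messages with the same mapped role are merged into one ContentDict.
# Expected output, roughly:
# [{'role': 'model', 'parts': ['# System:\nYou are a concise assistant.']},
#  {'role': 'user', 'parts': ['Hello!', 'What can you do?']},
#  {'role': 'model', 'parts': ['I can answer questions.']}]
print(_format(messages))
```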