Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemini client #1953

Merged
merged 21 commits into from
Jul 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Code formatting
  • Loading branch information
kevinmessiaen committed Jun 20, 2024
commit a6f1fc15ea18110af433260ec343be7dcffd3e0f
12 changes: 5 additions & 7 deletions tests/llm/test_llm_client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import json
from unittest.mock import MagicMock, Mock

from google.generativeai.types import ContentDict
from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionResponseChoice
from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
from mistralai.models.chat_completion import FinishReason, UsageInfo
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from google.generativeai.types import ContentDict

from giskard.llm.client import ChatMessage
from giskard.llm.client.bedrock import ClaudeBedrockClient
from giskard.llm.client.gemini import GeminiClient
from giskard.llm.client.mistral import MistralClient
from giskard.llm.client.openai import OpenAIClient
from giskard.llm.client.gemini import GeminiClient

DEMO_OPENAI_RESPONSE = ChatCompletion(
id="chatcmpl-abc123",
Expand Down Expand Up @@ -122,14 +122,12 @@ def test_claude_bedrock_client():
assert isinstance(res, ChatMessage)
assert res.content == "This is a test!"


def test_gemini_client():
# Mock the Gemini client
gemini_api_client = Mock()
gemini_api_client.generate_content = MagicMock(
return_value=Mock(
text="This is a test!",
candidates=[Mock(content=Mock(role="assistant"))]
)
return_value=Mock(text="This is a test!", candidates=[Mock(content=Mock(role="assistant"))])
)
gemini_api_client.count_tokens = MagicMock(
side_effect=lambda text: sum(len(t.split()) for t in text) if isinstance(text, list) else len(text.split())
Expand All @@ -150,4 +148,4 @@ def test_gemini_client():

# Assert that the response is a ChatMessage and has the correct content
assert isinstance(res, ChatMessage)
assert res.content == "This is a test!"
assert res.content == "This is a test!"
Loading