Added seed and format params to ClaudeBedrockClient.complete
These params will be ignored, but that's not a big issue.

Missing them would cause certain features to fail entirely when using this model.
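
For context, here is a minimal, self-contained sketch of the failure mode the message describes. The stub below stands in for the old signature and is not giskard code; all names are illustrative:

from typing import Optional, Sequence


class ChatMessage:
    # Minimal stand-in for giskard's ChatMessage, for illustration only.
    def __init__(self, role: str, content: str):
        self.role, self.content = role, content


def complete_old(
    messages: Sequence[ChatMessage],
    temperature: float = 1,
    max_tokens: Optional[int] = 1000,
    caller_id: Optional[str] = None,
) -> str:
    # Old signature: no seed or format parameters.
    return "ok"


# A feature that forwards seed/format crashes against the old signature:
try:
    complete_old(messages=[ChatMessage("user", "hi")], seed=42, format="json")
except TypeError as err:
    print(err)  # complete_old() got an unexpected keyword argument 'seed'

Accepting and silently ignoring the extra keywords, as this commit does, keeps such callers working.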
kevinmessiaen committed Apr 22, 2024
1 parent 1e45fa2 commit c1da31c
Showing 1 changed file with 34 additions and 52 deletions.
86 changes: 34 additions & 52 deletions giskard/llm/client/bedrock.py
@@ -1,4 +1,5 @@
 from typing import Optional, Sequence
+
 import json
 
 from ..config import LLMConfigurationError
@@ -15,27 +16,28 @@
 
 
 class ClaudeBedrockClient(LLMClient):
-    def __init__(self,
-                 bedrock_runtime_client,
-                 model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
-                 anthropic_version: str = "bedrock-2023-05-31"):
+    def __init__(
+        self,
+        bedrock_runtime_client,
+        model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
+        anthropic_version: str = "bedrock-2023-05-31",
+    ):
         self._client = bedrock_runtime_client
         self.model = model
         self.anthropic_version = anthropic_version
 
     def complete(
-            self,
-            messages: Sequence[ChatMessage],
-            temperature: float = 1,
-            max_tokens: Optional[int] = 1000,
-            caller_id: Optional[str] = None,
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
     ) -> ChatMessage:
-
         # only supporting claude 3 to start
-        if 'claude-3' not in self.model:
-            raise LLMConfigurationError(
-                f"Only claude-3 models are supported as of now, got {self.model}"
-            )
+        if "claude-3" not in self.model:
+            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")
 
         # extract system prompt from messages
         system_prompt = ""
@@ -48,57 +50,37 @@ def complete(
         input_msg_prompt = []
         for msg in messages:
             if msg.role.lower() == "assistant":
-                input_msg_prompt.append(
-                    {
-                        "role": "assistant",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": msg.content
-                            }
-                        ]
-                    }
-                )
+                input_msg_prompt.append({"role": "assistant", "content": [{"type": "text", "text": msg.content}]})
             else:
-                input_msg_prompt.append(
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": msg.content
-                            }
-                        ]
-                    }
-                )
+                input_msg_prompt.append({"role": "user", "content": [{"type": "text", "text": msg.content}]})
 
         # create the json body to send to the API
-        body = json.dumps({
-            "anthropic_version": "bedrock-2023-05-31",
-            "max_tokens": max_tokens,
-            "temperature": temperature,
-            "system": system_prompt,
-            "messages": input_msg_prompt
-        })
+        body = json.dumps(
+            {
+                "anthropic_version": "bedrock-2023-05-31",
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "system": system_prompt,
+                "messages": input_msg_prompt,
+            }
+        )
 
         # invoke the model and get the response
         try:
-            accept = 'application/json'
-            contentType = 'application/json'
-            response = self._client.invoke_model(
-                body=body, modelId=self.model, accept=accept, contentType=contentType
-            )
-            completion = json.loads(response.get('body').read())
+            accept = "application/json"
+            contentType = "application/json"
+            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
+            completion = json.loads(response.get("body").read())
         except RuntimeError as err:
             raise LLMConfigurationError("Could not get response from Bedrock API") from err
 
         self.logger.log_call(
-            prompt_tokens=completion['usage']['input_tokens'],
-            sampled_tokens=completion['usage']['input_tokens'],
+            prompt_tokens=completion["usage"]["input_tokens"],
+            sampled_tokens=completion["usage"]["input_tokens"],
             model=self.model,
            client_class=self.__class__.__name__,
             caller_id=caller_id,
         )
 
-        msg = completion['content'][0]['text']
+        msg = completion["content"][0]["text"]
         return ChatMessage(role="assistant", content=msg)
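
For reference, a minimal usage sketch of the updated client. The boto3 calls are standard for Bedrock; the region, prompt, and the ChatMessage import path are assumptions, not part of this commit:

import boto3

from giskard.llm.client import ChatMessage  # import path assumed
from giskard.llm.client.bedrock import ClaudeBedrockClient

# boto3's "bedrock-runtime" client exposes invoke_model, which complete() wraps.
runtime = boto3.client("bedrock-runtime", region_name="us-east-1")
client = ClaudeBedrockClient(runtime)

# seed and format are now accepted (and silently ignored), so callers that
# forward them no longer fail with a TypeError.
reply = client.complete(
    messages=[ChatMessage(role="user", content="Say hello.")],
    temperature=0.7,
    seed=1234,
    format=None,
)
print(reply.content)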
