Code reformatting
kevinmessiaen committed Jun 4, 2024
1 parent bc6313a commit c53d475
Showing 1 changed file with 82 additions and 65 deletions.
giskard/llm/client/bedrock.py (147 changes: 82 additions & 65 deletions)
@@ -1,6 +1,7 @@
-from typing import List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 import json
+from abc import ABC, abstractmethod
 
 from ..config import LLMConfigurationError
 from ..errors import LLMImportError
@@ -15,80 +16,109 @@
     ) from err
 
 
-def _format_messages_claude(messages: Sequence[ChatMessage]):
-    input_msg_prompt: List = []
-    system_prompts = []
+class BaseBedrockClient(LLMClient, ABC):
+    def __init__(self, bedrock_runtime_client, model: str):
+        self._client = bedrock_runtime_client
+        self.model = model
 
-    for msg in messages:
-        # System prompt is a specific parameter in Claude
-        if msg.role.lower() == "system":
-            system_prompts.append(msg.content)
-            continue
+    @abstractmethod
+    def _format_body(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> Dict:
+        ...
 
-        # Only role user and assistant are allowed
-        role = msg.role.lower()
-        role = role if role in ["assistant", "user"] else "user"
+    @abstractmethod
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
+        ...
 
-        # Consecutive messages need to be grouped
-        last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
-        if last_message is not None and last_message["role"] == role:
-            last_message["content"].append({"type": "text", "text": msg.content})
-            continue
+    def complete(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1,
+        max_tokens: Optional[int] = 1000,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> ChatMessage:
+        # create the json body to send to the API
+        body = self._format_body(messages, temperature, max_tokens, caller_id, seed, format)
 
-        input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})
+        # invoke the model and get the response
+        try:
+            accept = "application/json"
+            contentType = "application/json"
+            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
+            completion = json.loads(response.get("body").read())
+        except RuntimeError as err:
+            raise LLMConfigurationError("Could not get response from Bedrock API") from err
 
-    return input_msg_prompt, "\n".join(system_prompts)
+        return self._parse_completion(completion, caller_id)
 
 
-class ClaudeBedrockClient(LLMClient):
+class ClaudeBedrockClient(BaseBedrockClient):
     def __init__(
         self,
         bedrock_runtime_client,
         model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
         anthropic_version: str = "bedrock-2023-05-31",
     ):
-        self._client = bedrock_runtime_client
-        self.model = model
+        # only supporting claude 3
+        if "claude-3" not in self.model:
+            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")
+
+        super().__init__(bedrock_runtime_client, model)
         self.anthropic_version = anthropic_version
 
-    def complete(
+    def _format_body(
         self,
         messages: Sequence[ChatMessage],
         temperature: float = 1,
         max_tokens: Optional[int] = 1000,
         caller_id: Optional[str] = None,
         seed: Optional[int] = None,
         format=None,
-    ) -> ChatMessage:
-        # only supporting claude 3 to start
-        if "claude-3" not in self.model:
-            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")
+    ) -> Dict:
+        input_msg_prompt: List = []
+        system_prompts = []
 
-        messages, system = _format_messages_claude(messages)
+        for msg in messages:
+            # System prompt is a specific parameter in Claude
+            if msg.role.lower() == "system":
+                system_prompts.append(msg.content)
+                continue
 
-        # create the json body to send to the API
-        body = json.dumps(
+            # Only role user and assistant are allowed
+            role = msg.role.lower()
+            role = role if role in ["assistant", "user"] else "user"
+
+            # Consecutive messages need to be grouped
+            last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
+            if last_message is not None and last_message["role"] == role:
+                last_message["content"].append({"type": "text", "text": msg.content})
+                continue
+
+            input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})
+
+        return json.dumps(
             {
                 "anthropic_version": "bedrock-2023-05-31",
                 "max_tokens": max_tokens,
                 "temperature": temperature,
-                "system": system,
+                "system": "\n".join(system_prompts),
                 "messages": messages,
             }
         )
 
-        # invoke the model and get the response
-        try:
-            accept = "application/json"
-            contentType = "application/json"
-            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
-            completion = json.loads(response.get("body").read())
-        except RuntimeError as err:
-            raise LLMConfigurationError("Could not get response from Bedrock API") from err
-
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
         self.logger.log_call(
             prompt_tokens=completion["usage"]["input_tokens"],
-            sampled_tokens=completion["usage"]["input_tokens"],
+            sampled_tokens=completion["usage"]["output_tokens"],
             model=self.model,
             client_class=self.__class__.__name__,
             caller_id=caller_id,
@@ -98,52 +128,39 @@ def complete(
         return ChatMessage(role="assistant", content=msg)
 
 
-class LLamaBedrockClient(LLMClient):
-    def __init__(
-        self,
-        bedrock_runtime_client,
-        model: str = "meta.llama3-8b-instruct-v1:0",
-    ):
-        self._client = bedrock_runtime_client
-        self.model = model
+class LLamaBedrockClient(BaseBedrockClient):
+    def __init__(self, bedrock_runtime_client, model: str = "meta.llama3-8b-instruct-v1:0"):
+        # only supporting llama
+        if "llama" not in self.model:
+            raise LLMConfigurationError(f"Only Llama models are supported as of now, got {self.model}")
 
-    def complete(
+        super().__init__(bedrock_runtime_client, model)
+
+    def _format_body(
         self,
         messages: Sequence[ChatMessage],
         temperature: float = 1,
         max_tokens: Optional[int] = 1000,
         caller_id: Optional[str] = None,
         seed: Optional[int] = None,
         format=None,
-    ) -> ChatMessage:
-        # only supporting llama
-        if "llama" not in self.model:
-            raise LLMConfigurationError(f"Only Llama models are supported as of now, got {self.model}")
-
+    ) -> Dict:
         # Create the messages format needed for llama bedrock specifically
         prompts = []
         for msg in messages:
             prompts.append(f"# {msg.role}:\n{msg.content}\n")
 
         # create the json body to send to the API
         messages = "\n".join(prompts)
-        body = json.dumps(
+        return json.dumps(
             {
                 "max_gen_len": max_tokens,
                 "temperature": temperature,
                 "prompt": f"{messages}\n# assistant:\n",
             }
         )
 
-        # invoke the model and get the response
-        try:
-            accept = "application/json"
-            contentType = "application/json"
-            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
-            completion = json.loads(response.get("body").read())
-        except RuntimeError as err:
-            raise LLMConfigurationError("Could not get response from Bedrock API") from err
-
+    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
        self.logger.log_call(
             prompt_tokens=completion["prompt_token_count"],
             sampled_tokens=completion["generation_token_count"],
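
The refactoring keeps the request/response plumbing in BaseBedrockClient.complete() and pushes everything provider-specific behind two hooks, _format_body and _parse_completion. As a rough illustration of that extension point, here is a minimal, hypothetical subclass; it is not part of the commit, the class name, request fields, and response key are invented for the example, and the import paths assume giskard and boto3 are installed.

# Hypothetical sketch only: shows how another Bedrock model family could plug into
# the BaseBedrockClient introduced by this commit. The request fields and the
# "output_text" response key are assumptions, not a real Bedrock schema.
import json
from typing import Dict, Optional, Sequence

from giskard.llm.client import ChatMessage  # assumed import path
from giskard.llm.client.bedrock import BaseBedrockClient


class ExampleBedrockClient(BaseBedrockClient):
    def _format_body(
        self,
        messages: Sequence[ChatMessage],
        temperature: float = 1,
        max_tokens: Optional[int] = 1000,
        caller_id: Optional[str] = None,
        seed: Optional[int] = None,
        format=None,
    ) -> Dict:
        # Flatten the chat history into a single prompt and return the JSON string
        # that the shared complete() passes to invoke_model, mirroring the Claude
        # and Llama implementations above.
        prompt = "\n".join(f"# {msg.role}:\n{msg.content}\n" for msg in messages)
        return json.dumps({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature})

    def _parse_completion(self, completion, caller_id: Optional[str] = None) -> ChatMessage:
        # The real subclasses also call self.logger.log_call here; omitted for brevity.
        return ChatMessage(role="assistant", content=completion["output_text"])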
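For reference, the role-grouping rule that ClaudeBedrockClient._format_body applies to the chat history can be exercised on its own. The snippet below is a standalone rendition of that loop on made-up data, not code from the repository.

# Standalone rendition of the Claude message-grouping rule from the diff above:
# system messages are collected separately, roles other than user/assistant are
# coerced to "user", and consecutive messages with the same role are merged into
# a single entry holding several text blocks.
conversation = [
    ("system", "You are terse."),
    ("user", "Hi"),
    ("user", "How are you?"),
    ("assistant", "Fine."),
]

input_msg_prompt, system_prompts = [], []
for role, content in conversation:
    role = role.lower()
    if role == "system":
        system_prompts.append(content)
        continue
    role = role if role in ("assistant", "user") else "user"
    if input_msg_prompt and input_msg_prompt[-1]["role"] == role:
        input_msg_prompt[-1]["content"].append({"type": "text", "text": content})
        continue
    input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": content}]})

print("\n".join(system_prompts))  # -> You are terse.
print(input_msg_prompt)  # the two consecutive user messages end up merged into one entry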
