Skip to content

Commit

Permalink
Group message by role when using claude bedrock client
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmessiaen committed Jun 4, 2024
1 parent 6135e13 commit 1764966
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions giskard/llm/client/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import List, Optional, Sequence

import json

Expand All @@ -15,6 +15,31 @@
) from err


def _format_messages_claude(messages: Sequence[ChatMessage]):
input_msg_prompt: List = []
system_prompts = []

for msg in messages:
# System prompt is a specific parameter in Claude
if msg.role.lower() == "system":
system_prompts.append(msg.content)
continue

# Only role user and assistant are allowed
role = msg.role.lower()
role = role if role in ["assistant", "user"] else "user"

# Consecutive messages need to be grouped
last_message = None if len(input_msg_prompt) == 0 else input_msg_prompt[-1]
if last_message is not None and last_message["role"] == role:
last_message["content"].append({"type": "text", "text": msg.content})
continue

input_msg_prompt.append({"role": role, "content": [{"type": "text", "text": msg.content}]})

return input_msg_prompt, "\n".join(system_prompts)


class ClaudeBedrockClient(LLMClient):
def __init__(
self,
Expand All @@ -39,25 +64,16 @@ def complete(
if "claude-3" not in self.model:
raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")

# Create the messages format needed for bedrock specifically
input_msg_prompt = []
system_prompts = []
for msg in messages:
if msg.role.lower() == "system":
system_prompts.append(msg.content)
elif msg.role.lower() == "assistant":
input_msg_prompt.append({"role": "assistant", "content": [{"type": "text", "text": msg.content}]})
else:
input_msg_prompt.append({"role": "user", "content": [{"type": "text", "text": msg.content}]})
messages, system = _format_messages_claude(messages)

# create the json body to send to the API
body = json.dumps(
{
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": max_tokens,
"temperature": temperature,
"system": "\n".join(system_prompts),
"messages": input_msg_prompt,
"system": system,
"messages": messages,
}
)

Expand Down

0 comments on commit 1764966

Please sign in to comment.