refactor: pass a role string to OpenAI API (#7404)
* draft

* rm unused imports
anakin87 authored Mar 22, 2024
1 parent e779d43 commit c789f90
Showing 5 changed files with 38 additions and 39 deletions.
3 changes: 2 additions & 1 deletion haystack/components/generators/chat/hugging_face_tgi.py
@@ -228,8 +228,9 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             raise RuntimeError("Please call warm_up() before running LLM inference.")
 
         # apply either model's chat template or the user-provided one
+        formatted_messages = [message.to_openai_format() for message in messages]
         prepared_prompt: str = self.tokenizer.apply_chat_template(
-            conversation=messages, chat_template=self.chat_template, tokenize=False
+            conversation=formatted_messages, chat_template=self.chat_template, tokenize=False
         )
         prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False))
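A note on the change above: `tokenizer.apply_chat_template` consumes plain dictionaries with `role` and `content` keys, not `ChatMessage` dataclasses, which is why the messages are converted before templating. A minimal sketch of the same flow outside the component (the messages are illustrative; the model mirrors the one used in the tests below):

from transformers import AutoTokenizer

from haystack.dataclasses import ChatMessage

messages = [
    ChatMessage.from_system("You are good assistant"),
    ChatMessage.from_user("I have a question"),
]

# Convert ChatMessage objects into [{"role": ..., "content": ...}] dicts first;
# apply_chat_template does not understand Haystack's dataclass directly.
formatted_messages = [m.to_openai_format() for m in messages]

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
prompt = tokenizer.apply_chat_template(formatted_messages, tokenize=False)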
19 changes: 2 additions & 17 deletions haystack/components/generators/chat/openai.py
@@ -1,9 +1,8 @@
 import copy
-import dataclasses
 import json
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from openai import OpenAI, Stream  # type: ignore
+from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
@@ -169,7 +168,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]
 
         chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
@@ -204,20 +203,6 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
 
         return {"replies": completions}
 
-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-        :param messages: The list of ChatMessage.
-        :return: The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
     def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
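Why the helper above could be dropped: `dataclasses.asdict` recurses into nested dataclasses and containers but leaves enum members untouched, so the old code put a `ChatRole` member under `"role"` rather than the bare string the API expects (hence the commit title). A self-contained sketch of the difference, using stand-in classes modeled loosely on Haystack's `ChatRole` and `ChatMessage`:

import dataclasses
from dataclasses import dataclass
from enum import Enum


class Role(Enum):  # stand-in for Haystack's ChatRole
    USER = "user"


@dataclass
class Message:  # stand-in for ChatMessage
    role: Role
    content: str


m = Message(role=Role.USER, content="I have a question")

# Old path: asdict() keeps the enum member, so the payload carries Role.USER.
old_payload = {k: v for k, v in dataclasses.asdict(m).items() if k in {"role", "content", "name"} and v}
assert old_payload["role"] is Role.USER

# New path: to_openai_format() reads .value explicitly and gets a plain string.
new_payload = {"role": m.role.value, "content": m.content}
assert new_payload["role"] == "user"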
20 changes: 1 addition & 19 deletions haystack/components/generators/openai.py
@@ -1,4 +1,3 @@
-import dataclasses
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
@@ -164,7 +163,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]
 
         completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
@@ -200,23 +199,6 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
             "meta": [message.meta for message in completions],
         }
 
-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-        :param messages:
-            The list of ChatMessage.
-        :returns:
-            The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
     def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
16 changes: 16 additions & 0 deletions haystack/dataclasses/chat_message.py
@@ -28,6 +28,22 @@ class ChatMessage:
     name: Optional[str]
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
 
+    def to_openai_format(self) -> Dict[str, Any]:
+        """
+        Convert the message to the format expected by OpenAI's Chat API.
+        See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
+
+        :returns: A dictionary with the following keys:
+            - `role`
+            - `content`
+            - `name` (optional)
+        """
+        msg = {"role": self.role.value, "content": self.content}
+        if self.name:
+            msg["name"] = self.name
+
+        return msg
+
     def is_from(self, role: ChatRole) -> bool:
         """
         Check if the message is from a specific role.
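With the method in place, any `ChatMessage` can be handed straight to the OpenAI client. A short usage sketch (assuming `OPENAI_API_KEY` is set in the environment and the v1 `openai` client; the model name is illustrative):

from openai import OpenAI

from haystack.dataclasses import ChatMessage

client = OpenAI()  # reads OPENAI_API_KEY from the environment

messages = [
    ChatMessage.from_system("You are good assistant"),
    ChatMessage.from_user("I have a question"),
]

# The dicts returned by to_openai_format() match the Chat Completions schema.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[m.to_openai_format() for m in messages],
)
print(response.choices[0].message.content)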
19 changes: 17 additions & 2 deletions test/dataclasses/test_chat_message.py
@@ -37,11 +37,23 @@ def test_from_function_with_empty_name():
     assert message.name == ""
 
 
+def test_to_openai_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert message.to_openai_format() == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert message.to_openai_format() == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_function("Function call", "function_name")
+    assert message.to_openai_format() == {"role": "function", "content": "Function call", "name": "function_name"}
+
+
 @pytest.mark.integration
 def test_apply_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
     assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"

@@ -61,5 +73,8 @@ def test_apply_custom_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
     # could be any tokenizer, let's use the one we already likely have in cache
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(
+        formatted_messages, chat_template=anthropic_template, tokenize=False
+    )
     assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
