Skip to content

Commit

Permalink
Implement "tool" and "file" capabilities
Browse files Browse the repository at this point in the history
Still not completely finished, currently a work in progress
  • Loading branch information
Whitelisted1 committed Jun 4, 2024
1 parent 82b8e53 commit fa40ab2
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/hugchat/hugchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def chat(
),
_stream_yield_all=_stream_yield_all,
web_search=web_search,
conversation=conversation
)
return msg

Expand Down
33 changes: 31 additions & 2 deletions src/hugchat/message.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Generator, Union

from .types.tool import Tool
from .types.file import File
from .types.message import Conversation
from .exceptions import ChatError, ModelOverloadedError
import json

RESPONSE_TYPE_FINAL = "finalAnswer"
RESPONSE_TYPE_STREAM = "stream"
RESPONSE_TYPE_TOOL = "tool" # with subtypes "call" and "result"
RESPONSE_TYPE_FILE = "file"
RESPONSE_TYPE_WEB = "webSearch"
RESPONSE_TYPE_STATUS = "status"
MSGTYPE_ERROR = "error"
Expand Down Expand Up @@ -59,6 +64,8 @@ class Message(Generator):
_stream_yield_all: bool = False
web_search: bool = False
web_search_sources: list = []
tools_used: list = []
files_created: list = []
_result_text: str = ""
web_search_done: bool = not web_search
msg_status: int = MSGSTATUS_PENDING
Expand All @@ -69,10 +76,12 @@ def __init__(
g: Generator,
_stream_yield_all: bool = False,
web_search: bool = False,
conversation: Conversation = None
) -> None:
self.g = g
self._stream_yield_all = _stream_yield_all
self.web_search = web_search
self.conversation = conversation

@property
def text(self) -> str:
Expand All @@ -97,8 +106,7 @@ def __next__(self) -> dict:
if self.error is not None:
raise self.error
else:
raise Exception(
"Message stauts is `Rejected` but no error found")
raise Exception("Message status is `Rejected` but no error found")

try:
a: dict = next(self.g)
Expand All @@ -118,6 +126,13 @@ def __next__(self) -> dict:
wss.title = source["title"]
wss.link = source["link"]
self.web_search_sources.append(wss)
elif t == RESPONSE_TYPE_TOOL:
if a["subtype"] == "result":
tool = Tool(a["uuid"], a["result"])
self.tools_used.append(tool)
elif t == RESPONSE_TYPE_FILE:
file = File(a["sha"], a["name"], a["mime"], self.conversation)
self.files_created.append(file)
elif "messageType" in a:
message_type: str = a["messageType"]
if message_type == MSGTYPE_ERROR:
Expand Down Expand Up @@ -185,6 +200,20 @@ def get_search_sources(self) -> list:
"""
return self.web_search_sources

def get_tools_used(self) -> list:
"""
:Return:
- self.tools_used
"""
return self.tools_used

def get_files_created(self) -> list:
"""
:Return:
- self.files_created
"""
return self.files_created

def search_enabled(self) -> bool:
"""
:Return:
Expand Down
24 changes: 24 additions & 0 deletions src/hugchat/types/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .message import Conversation


class File:
'''
Class used to represent files created by the model
'''

def __init__(self, sha: str, name: str, mime: str, conversation: Conversation):
self.sha = sha
self.name = name
self.mime = mime

self.conversation = conversation
self.url = self.get_url()

def get_url(self) -> str:
print(self.conversation)
print(dir(self.conversation))
print(self.conversation.id)
return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}"

def __str__(self) -> str:
return f"File(url={self.url}, sha={self.sha}, name={self.name}, mime={self.mime})"
14 changes: 14 additions & 0 deletions src/hugchat/types/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass


@dataclass
class Tool:
'''
Class used to represent tools used by the model
'''

uuid: str
result: str

def __str__(self) -> str:
return f"Tool(uuid={self.uuid}, result={self.result})"

0 comments on commit fa40ab2

Please sign in to comment.