diff --git a/.gitignore b/.gitignore
index 5e8f6be..a70f842 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,6 @@ cython_debug/
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-.vscode/
\ No newline at end of file
+.vscode/
+
+gpt.yml
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..38c1f34
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,23 @@
+FROM python:3.10-bullseye
+WORKDIR /app
+
+COPY requirements.txt ./
+RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements.txt
+
+COPY requirements_docker.txt ./
+RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements_docker.txt
+
+COPY . .
+
+RUN mkdir -p /mnt/output
+
+RUN adduser --disabled-password gpt
+USER gpt
+RUN mkdir -p $HOME/.config/gpt-cli
+RUN cp /app/gpt.yml $HOME/.config/gpt-cli/gpt.yml
+
+
+WORKDIR /mnt/output
+
+ENV GPTCLI_ALLOW_CODE_EXECUTION=1
+ENTRYPOINT ["python", "/app/gpt.py"]
diff --git a/gpt.py b/gpt.py
index 9040d2a..671f174 100755
--- a/gpt.py
+++ b/gpt.py
@@ -24,6 +24,7 @@
     choose_config_file,
     read_yaml_config,
 )
+from gptcli.interpreter import CodeInterpreterListener
 from gptcli.llama import init_llama_models
 from gptcli.logging import LoggingChatListener
 from gptcli.cost import PriceChatListener
@@ -89,6 +90,12 @@ def parse_args(config: GptCliConfig):
         default=config.log_file,
         help="The file to write logs to",
     )
+    parser.add_argument(
+        "--history_file",
+        type=str,
+        default=config.history_file,
+        help="The file to write chat history to",
+    )
     parser.add_argument(
         "--log_level",
         type=str,
@@ -215,8 +222,14 @@ def __init__(self, assistant: Assistant, markdown: bool, show_price: bool):
         if show_price:
             listeners.append(PriceChatListener(assistant))
 
+        if os.environ.get("GPTCLI_ALLOW_CODE_EXECUTION") == "1":
+            listeners.append(CodeInterpreterListener("python"))
+
         listener = CompositeChatListener(listeners)
-        super().__init__(assistant, listener)
+        super().__init__(
+            assistant,
+            listener,
+        )
 
 
 def run_interactive(args, assistant):
@@ -224,7 +237,9 @@ def run_interactive(args, assistant):
     session = CLIChatSession(
         assistant=assistant, markdown=args.markdown, show_price=args.show_price
    )
-    history_filename = os.path.expanduser("~/.config/gpt-cli/history")
+    history_filename = args.history_file or os.path.expanduser(
+        "~/.config/gpt-cli/history"
+    )
     os.makedirs(os.path.dirname(history_filename), exist_ok=True)
     input_provider = CLIUserInputProvider(history_filename=history_filename)
     session.loop(input_provider)
diff --git a/gpt.yml.template b/gpt.yml.template
new file mode 100644
index 0000000..11bcff2
--- /dev/null
+++ b/gpt.yml.template
@@ -0,0 +1,12 @@
+markdown: True
+openai_api_key:
+log_file: /mnt/gpt.log
+log_level: DEBUG
+history_file: /mnt/history
+assistants:
+  python:
+    model: gpt-4-0613
+    enable_code_execution: True
+    messages:
+      - { role: "system", content: "You are a helpful assistant. You have access to a Python environment. You can install missing packages. You have access to the internet. The user can see the code you are executing and its output: do not repeat them to the user verbatim. Pre-installed packages: numpy, matplotlib, scipy, scikit-learn, pandas, ipython, ipykernel." }
+
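The Dockerfile expects a real gpt.yml (copied from gpt.yml.template with openai_api_key filled in) to be present in the build context, since it bakes that file into the gpt user's config directory; gpt.yml is added to .gitignore above. A minimal sketch of the intended workflow, where the image tag and host directory are illustrative rather than fixed by the diff:

    cp gpt.yml.template gpt.yml        # then fill in openai_api_key
    DOCKER_BUILDKIT=1 docker build -t gpt-cli .
    docker run -it -v "$PWD/sandbox:/mnt" gpt-cli python

The container sets GPTCLI_ALLOW_CODE_EXECUTION=1 and works out of /mnt/output, so generated files, the log, and the history land under the mounted directory while execution stays inside the container; "python" here is the assistant name defined in the template.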
diff --git a/gptcli/anthropic.py b/gptcli/anthropic.py
index a0cda26..041a2da 100644
--- a/gptcli/anthropic.py
+++ b/gptcli/anthropic.py
@@ -1,8 +1,15 @@
+import logging
 import os
 from typing import Iterator, List
 
 import anthropic
 
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import (
+    Completion,
+    CompletionProvider,
+    Message,
+    make_completion,
+    make_completion_iter,
+)
 
 api_key = os.environ.get("ANTHROPIC_API_KEY")
@@ -25,7 +32,10 @@ def role_to_name(role: str) -> str:
 
 def make_prompt(messages: List[Message]) -> str:
     prompt = "\n".join(
-        [f"{role_to_name(message['role'])}{message['content']}" for message in messages]
+        [
+            f"{role_to_name(message['role'])}{message.get('content', '')}"
+            for message in messages
+        ]
     )
     prompt += f"{role_to_name('assistant')}"
     return prompt
@@ -33,8 +43,15 @@ def make_prompt(messages: List[Message]) -> str:
 
 class AnthropicCompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
+        if enable_code_execution:
+            raise ValueError("Code execution is not supported by Anthropic models")
+
         kwargs = {
             "prompt": make_prompt(messages),
             "stop_sequences": [anthropic.HUMAN_PROMPT],
@@ -49,14 +66,19 @@ def complete(
         client = get_client()
         if stream:
             response = client.completion_stream(**kwargs)
-        else:
-            response = [client.completion(**kwargs)]
 
-        prev_completion = ""
-        for data in response:
-            next_completion = data["completion"]
-            yield next_completion[len(prev_completion) :]
-            prev_completion = next_completion
+            def content_iter() -> Iterator[str]:
+                prev_completion = ""
+                for data in response:
+                    next_completion = data["completion"]
+                    yield next_completion[len(prev_completion) :]
+                    prev_completion = next_completion
+
+            for x in make_completion_iter(content_iter()):
+                yield x
+        else:
+            response = client.completion(**kwargs)
+            yield make_completion(response["completion"], finish_reason="stop")
 
 
 def num_tokens_from_messages_anthropic(messages: List[Message], model: str) -> int:
@@ -65,4 +87,4 @@
 
 
 def num_tokens_from_completion_anthropic(message: Message, model: str) -> int:
-    return anthropic.count_tokens(message["content"])
+    return anthropic.count_tokens(message.get("content") or "")
diff --git a/gptcli/assistant.py b/gptcli/assistant.py
index 6adac67..8398f2c 100644
--- a/gptcli/assistant.py
+++ b/gptcli/assistant.py
@@ -4,7 +4,7 @@
 import platform
 from typing import Any, Dict, Iterator, Optional, TypedDict, List
 
-from gptcli.completion import CompletionProvider, ModelOverrides, Message
+from gptcli.completion import Completion, CompletionProvider, ModelOverrides, Message
 from gptcli.google import GoogleCompletionProvider
 from gptcli.llama import LLaMACompletionProvider
 from gptcli.openai import OpenAICompletionProvider
@@ -16,12 +16,14 @@ class AssistantConfig(TypedDict, total=False):
     model: str
     temperature: float
     top_p: float
+    enable_code_execution: bool
 
 
 CONFIG_DEFAULTS = {
     "model": "gpt-3.5-turbo",
     "temperature": 0.7,
     "top_p": 1.0,
+    "enable_code_execution": False,
 }
 
 DEFAULT_ASSISTANTS: Dict[str, AssistantConfig] = {
@@ -89,7 +91,7 @@ def init_messages(self) -> List[Message]:
         return self.config.get("messages", [])[:]
 
     def supported_overrides(self) -> List[str]:
-        return ["model", "temperature", "top_p"]
+        return ["model", "temperature", "top_p", "enable_code_execution"]
 
     def _param(self, param: str, override_params: ModelOverrides) -> Any:
         # If the param is in the override_params, use that value
@@ -101,9 +103,15 @@ def _param(self, param: str, override_params: ModelOverrides) -> Any:
 
     def complete_chat(
         self, messages, override_params: ModelOverrides = {}, stream: bool = True
-    ) -> Iterator[str]:
+    ) -> Iterator[Completion]:
         model = self._param("model", override_params)
         completion_provider = get_completion_provider(model)
+
+        enable_code_execution = (
+            bool(self._param("enable_code_execution", override_params))
+            and os.environ.get("GPTCLI_ALLOW_CODE_EXECUTION") == "1"
+        )
+
         return completion_provider.complete(
             messages,
             {
@@ -112,6 +120,7 @@ def complete_chat(
                 "top_p": float(self._param("top_p", override_params)),
             },
             stream,
+            enable_code_execution,
         )
diff --git a/gptcli/cli.py b/gptcli/cli.py
index 4536596..f7e34e3 100644
--- a/gptcli/cli.py
+++ b/gptcli/cli.py
@@ -1,4 +1,8 @@
+import base64
+import logging
 import re
+import json
+from imgcat import imgcat
 from prompt_toolkit import PromptSession
 from prompt_toolkit.history import FileHistory
 from openai import OpenAIError, InvalidRequestError
@@ -9,6 +13,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 from rich.text import Text
+from gptcli.completion import FunctionCall, Message, merge_dicts
 from gptcli.session import (
     ALL_COMMANDS,
     COMMAND_CLEAR,
@@ -32,7 +37,7 @@ class StreamingMarkdownPrinter:
 
     def __init__(self, console: Console, markdown: bool):
         self.console = console
-        self.current_text = ""
+        self.current_message = {}
         self.markdown = markdown
         self.live: Optional[Live] = None
@@ -44,14 +49,55 @@ def __enter__(self) -> "StreamingMarkdownPrinter":
         self.live.__enter__()
         return self
 
-    def print(self, text: str):
-        self.current_text += text
+    def _format_function_call(self, function_call: FunctionCall) -> str:
+        text = ""
+        if function_call.get("name") == "python_eval":
+            source = function_call.get("arguments", "")
+            try:
+                source = json.loads(source).get("source", "")
+            except:
+                source = source + '"}'
+                try:
+                    source = json.loads(source).get("source", "")
+                except:
+                    source = ""
+
+            text += "\n\nExecuting Python code:\n"
+            text += f"```python\n{source}\n```"
+        else:
+            function_name = function_call.get("name", "?")
+            function_arguments = function_call.get("arguments", {})
+            text += f"""\n
+Calling function:
+
+```
+{function_name}({function_arguments})
+```"""
+        return text
+
+    def print(self, message_delta: Message):
+        self.current_message = merge_dicts(self.current_message, message_delta)
+
         if self.markdown:
             assert self.live
-            content = Markdown(self.current_text, style="green")
+            text = self.current_message.get("content", "")
+
+            function_call = self.current_message.get("function_call")
+            if function_call:
+                text += self._format_function_call(function_call)
+
+            content = Markdown(text, style="green")
             self.live.update(content)
             self.live.refresh()
         else:
+            text = message_delta.get("content") or ""
+            function_call = message_delta.get("function_call")
+            if function_call:
+                if "name" in function_call:
+                    text += function_call["name"]
+                if "arguments" in function_call:
+                    text += function_call["arguments"]
+
             self.console.print(Text(text, style="green"), end="")
 
     def __exit__(self, *args):
@@ -66,17 +112,29 @@ def __init__(self, console: Console, markdown: bool):
         self.console = console
         self.markdown = markdown
         self.printer = StreamingMarkdownPrinter(self.console, self.markdown)
-        self.first_token = True
 
     def __enter__(self):
         self.printer.__enter__()
         return self
 
-    def on_next_token(self, token: str):
-        if self.first_token and token.startswith(" "):
-            token = token[1:]
-        self.first_token = False
-        self.printer.print(token)
+    def on_message_delta(self, message_delta: Message):
+        self.printer.print(message_delta)
+
+    def on_function_result(self, result: dict):
+        self.console.print(Text("Function result:", style="yellow"))
+        if "image/png" in result:
+            image_base64 = result["image/png"]
+            image_bytes = base64.b64decode(image_base64)
+            imgcat(image_bytes)
+        if "text/plain" in result:
+            text = result["text/plain"]
+            if self.markdown:
+                content = Markdown(
+                    f"```\n{text}\n```",
+                )
+            else:
+                content = Text(text, style="yellow")
+            self.console.print(content)
 
     def __exit__(self, *args):
         self.printer.__exit__(*args)
diff --git a/gptcli/completion.py b/gptcli/completion.py
index ea82d8f..292cf42 100644
--- a/gptcli/completion.py
+++ b/gptcli/completion.py
@@ -1,21 +1,85 @@
 from abc import abstractmethod
-from typing import Iterator, List, TypedDict
+import logging
+from typing import Iterator, List, Optional, TypedDict
+from typing_extensions import Required
 
 
-class Message(TypedDict):
-    role: str
-    content: str
+class FunctionCall(TypedDict, total=False):
+    name: str
+    arguments: str
+
+
+class Message(TypedDict, total=False):
+    role: Required[str]
+    content: Optional[str]
+    name: str
+    function_call: FunctionCall
+
+
+def merge_dicts(a, b):
+    """
+    Given two nested dicts with string values, merge dict `b` into dict `a`, concatenating
+    string values.
+    """
+    for key, value in b.items():
+        if isinstance(value, dict):
+            a[key] = merge_dicts(a.get(key, {}), value)
+        elif value is not None:
+            a[key] = a.get(key, "") + value
+    return a
 
 
 class ModelOverrides(TypedDict, total=False):
     model: str
     temperature: float
     top_p: float
+    enable_code_execution: bool
+
+
+class CompletionDelta(TypedDict):
+    content: Optional[str]
+    function_call: Optional[FunctionCall]
+
+
+class Completion(TypedDict):
+    delta: Message
+    finish_reason: Optional[str]
+
+
+def make_completion(
+    content_delta: str,
+    role: str = "assistant",
+    finish_reason: Optional[str] = None,
+) -> Completion:
+    delta: Message = {
+        "role": role,
+        "content": content_delta,
+    }
+    return {
+        "delta": delta,
+        "finish_reason": finish_reason,
+    }
+
+
+def make_completion_iter(
+    content_iter: Iterator[str],
+    role: str = "assistant",
+    finish_reason: Optional[str] = "stop",
+) -> Iterator[Completion]:
+    logging.debug("make_completion_iter")
+    yield make_completion("", role=role)
+    for content in content_iter:
+        yield make_completion(content, role="")
+    yield make_completion("", role="", finish_reason=finish_reason)
 
 
 class CompletionProvider:
     @abstractmethod
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
         pass
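merge_dicts is the glue that lets a consumer rebuild one assistant message out of streamed deltas, including nested function_call fragments that arrive as partial strings. A small illustrative sketch (the values are invented for the example):

    # Illustrative only: accumulating streamed deltas into a single message.
    message: Message = {"role": ""}
    deltas = [
        {"role": "assistant"},
        {"function_call": {"name": "python", "arguments": '{"source": '}},
        {"function_call": {"arguments": '"print(42)"}'}},
    ]
    for delta in deltas:
        message = merge_dicts(message, delta)
    # message == {"role": "assistant",
    #             "function_call": {"name": "python",
    #                               "arguments": '{"source": "print(42)"}'}}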
diff --git a/gptcli/composite.py b/gptcli/composite.py
index ab74a11..2bf6dd0 100644
--- a/gptcli/composite.py
+++ b/gptcli/composite.py
@@ -2,7 +2,7 @@
 
 from gptcli.session import ChatListener, ResponseStreamer
 
-from typing import List
+from typing import List, Optional
 
 
 class CompositeResponseStreamer(ResponseStreamer):
@@ -14,9 +14,13 @@ def __enter__(self):
             streamer.__enter__()
         return self
 
-    def on_next_token(self, token: str):
+    def on_message_delta(self, message_delta: Message):
         for streamer in self.streamers:
-            streamer.on_next_token(token)
+            streamer.on_message_delta(message_delta)
+
+    def on_function_result(self, result: dict):
+        for streamer in self.streamers:
+            streamer.on_function_result(result)
 
     def __exit__(self, *args):
         for streamer in self.streamers:
@@ -57,3 +61,9 @@ def on_chat_response(
     ):
         for listener in self.listeners:
             listener.on_chat_response(messages, response, overrides)
+
+    def on_function_call(self, function_name: str, **kwargs) -> Optional[str]:
+        for listener in self.listeners:
+            result = listener.on_function_call(function_name, **kwargs)
+            if result is not None:
+                return result
diff --git a/gptcli/config.py b/gptcli/config.py
index 3cb9070..1a05914 100644
--- a/gptcli/config.py
+++ b/gptcli/config.py
@@ -23,6 +23,7 @@ class GptCliConfig:
     anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
     google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
     log_file: Optional[str] = None
+    history_file: Optional[str] = None
     log_level: str = "INFO"
     assistants: Dict[str, AssistantConfig] = {}
     interactive: Optional[bool] = None
diff --git a/gptcli/google.py b/gptcli/google.py
index e8c8c67..66bc844 100644
--- a/gptcli/google.py
+++ b/gptcli/google.py
@@ -1,6 +1,6 @@
 from typing import Iterator, List
 import google.generativeai as genai
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import Completion, CompletionProvider, Message, make_completion
 
 
 def role_to_author(role: str) -> str:
@@ -14,11 +14,16 @@ def role_to_author(role: str) -> str:
 
 def make_prompt(messages: List[Message]):
     system_messages = [
-        message["content"] for message in messages if message["role"] == "system"
+        message.get("content") or ""
+        for message in messages
+        if message["role"] == "system"
     ]
     context = "\n".join(system_messages)
     prompt = [
-        {"author": role_to_author(message["role"]), "content": message["content"]}
+        {
+            "author": role_to_author(message["role"]),
+            "content": message.get("content", ""),
+        }
         for message in messages
         if message["role"] != "system"
     ]
@@ -27,8 +32,15 @@
 
 class GoogleCompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
+        if enable_code_execution:
+            raise ValueError("Code execution is not supported by Google models")
+
         context, prompt = make_prompt(messages)
         kwargs = {
             "context": context,
@@ -40,4 +52,4 @@
         kwargs["top_p"] = args["top_p"]
 
         response = genai.chat(**kwargs)
-        yield response.last
+        yield make_completion(response.last, finish_reason="stop")
diff --git a/gptcli/interpreter.py b/gptcli/interpreter.py
new file mode 100644
index 0000000..281c4e5
--- /dev/null
+++ b/gptcli/interpreter.py
@@ -0,0 +1,88 @@
+import logging
+from typing import Optional
+from jupyter_client.manager import KernelManager
+from queue import Empty
+
+from gptcli.session import ChatListener
+
+
+class CodeInterpreterSession:
+    def __init__(self):
+        self.logger = logging.getLogger("gptcli-code-interpreter")
+
+        self.km = KernelManager()
+        self.km.start_kernel()
+
+        self.client = self.km.client()
+        self.client.start_channels()
+
+        # allow installing packages
+        self.execute("%colors NoColor")
+        self.execute("%load_ext autoreload")
+        self.execute("%autoreload 2")
+        self.execute("%matplotlib inline")
+        self.execute(
+            """
+import matplotlib.pyplot as plt
+plt.ioff()
+"""
+        )
+
+    def execute(self, code: str) -> dict:
+        self.logger.debug("Executing code: '%s'", code)
+
+        msg_id = self.client.execute(code)
+        state = "busy"
+        output = {}
+        while state != "idle":
+            try:
+                msg = self.client.get_iopub_msg(timeout=1)
+                content = msg["content"]
+                msg_type = msg["msg_type"]
+
+                if msg_type == "execute_result" or msg_type == "display_data":
+                    output = content["data"]
+                elif msg_type == "stream":
+                    output["text/plain"] = content["text"]
+                elif msg_type == "error":
+                    output["text/plain"] = "\n".join(content["traceback"])
+                elif msg_type == "status":
+                    state = content["execution_state"]
+            except Empty:
+                pass
+            except KeyboardInterrupt:
+                self.km.interrupt_kernel()
+                break
+
+        self.logger.debug("Code execution result: %s", output)
+
+        return output
+
+    def __del__(self):
+        self.client.stop_channels()
+        self.km.shutdown_kernel()
+
+
+class CodeInterpreterListener(ChatListener):
+    def __init__(self, function_name: str):
+        self.session: Optional[CodeInterpreterSession] = None
+        self.function_name = function_name
+
+    def on_chat_clear(self):
+        self.session = None
+
+    def on_function_call(self, function_name: str, **kwargs) -> Optional[dict]:
+        source = None
+        if function_name == self.function_name:
+            source = kwargs["source"]
+        elif function_name == "pip_install":
+            source = f"%pip install -qq --no-color {kwargs['package']}"
+
+        if source:
+            if self.session is None:
+                self.session = CodeInterpreterSession()
+            result = self.session.execute(source)
+            if function_name == "pip_install":
+                del self.session
+                self.session = None
+            return result
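The session above is a thin wrapper over jupyter_client: it starts a local IPython kernel, runs a snippet, and collects the resulting Jupyter MIME bundle. A rough usage sketch, assuming ipykernel is installed (as it is in the Docker image via requirements_docker.txt):

    # Illustrative only: execute a snippet and read back the MIME bundle.
    session = CodeInterpreterSession()
    result = session.execute("import numpy as np\nnp.arange(3).sum()")
    print(result.get("text/plain"))  # e.g. "6"
    # Displayed matplotlib figures can come back under the "image/png" key
    # (base64), which the CLI layer decodes and renders with imgcat.

Note that pip_install deliberately throws the session away afterwards, so the next call starts a fresh kernel that can import the newly installed package.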
diff --git a/gptcli/llama.py b/gptcli/llama.py
index 536ed65..1065f30 100644
--- a/gptcli/llama.py
+++ b/gptcli/llama.py
@@ -1,9 +1,16 @@
+import logging
 import os
 import sys
 from typing import Iterator, List, Optional, TypedDict, cast
-from llama_cpp import Completion, CompletionChunk, Llama
+from llama_cpp import Completion as LlamaCompletion, CompletionChunk, Llama
 
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import (
+    CompletionProvider,
+    Message,
+    Completion,
+    make_completion,
+    make_completion_iter,
+)
 
 
 class LLaMAModelConfig(TypedDict):
@@ -40,7 +47,7 @@ def role_to_name(role: str, model_config: LLaMAModelConfig) -> str:
 def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str:
     prompt = "\n".join(
         [
-            f"{role_to_name(message['role'], model_config)} {message['content']}"
+            f"{role_to_name(message['role'], model_config)} {message.get('content', '')}"
             for message in messages
         ]
     )
@@ -50,10 +57,17 @@ def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str:
 
 
 class LLaMACompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
         assert LLAMA_MODELS, "LLaMA models not initialized"
+
+        if enable_code_execution:
+            raise ValueError("Code execution is not supported by LLaMA models")
+
         model_config = LLAMA_MODELS[args["model"]]
 
         with suppress_stderr():
@@ -64,7 +78,6 @@ def complete(
                 use_mlock=True,
            )
         prompt = make_prompt(messages, model_config)
-        print(prompt)
 
         extra_args = {}
         if "temperature" in args:
@@ -80,11 +93,19 @@ def complete(
             echo=False,
             **extra_args,
         )
+
         if stream:
-            for x in cast(Iterator[CompletionChunk], gen):
-                yield x["choices"][0]["text"]
+
+            def completion_iter() -> Iterator[str]:
+                for data in cast(Iterator[CompletionChunk], gen):
+                    yield data["choices"][0]["text"]
+
+            for x in make_completion_iter(completion_iter()):
+                yield x
         else:
-            yield cast(Completion, gen)["choices"][0]["text"]
+            yield make_completion(
+                cast(LlamaCompletion, gen)["choices"][0]["text"], finish_reason="stop"
+            )
 
 
 # https://stackoverflow.com/a/50438156
diff --git a/gptcli/logging.py b/gptcli/logging.py
index b5ac84c..1061328 100644
--- a/gptcli/logging.py
+++ b/gptcli/logging.py
@@ -21,4 +21,6 @@ def on_error(self, e: Exception):
         self.logger.exception(e)
 
     def on_chat_message(self, message: Message):
-        self.logger.info(f"{message['role']}: {message['content']}")
+        self.logger.info(
+            f"{message['role']}: '{message.get('content')}', function_call={message.get('function_call')}"
+        )
diff --git a/gptcli/openai.py b/gptcli/openai.py
index 5a1a58f..e2b9e21 100644
--- a/gptcli/openai.py
+++ b/gptcli/openai.py
@@ -1,20 +1,59 @@
+import logging
 from typing import Any, Iterator, List, cast
 import openai
 import tiktoken
 
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import Completion, CompletionProvider, Message
+
+FUNCTIONS_SCHEMA = [
+    {
+        "name": "python",
+        "description": "Evaluate an arbitrary Python snippet",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "source": {
+                    "type": "string",
+                    "description": "The Python code to evaluate",
+                },
+            },
+            "required": ["source"],
+        },
+    },
+    {
+        "name": "pip_install",
+        "description": "Install a Python package. The kernel will be restarted automatically after the package is installed.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "package": {
+                    "type": "string",
+                    "description": "The package to install",
+                },
+            },
+            "required": ["package"],
+        },
+    },
+]
 
 
 class OpenAICompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
         kwargs = {}
         if "temperature" in args:
             kwargs["temperature"] = args["temperature"]
         if "top_p" in args:
             kwargs["top_p"] = args["top_p"]
+
+        if enable_code_execution:
+            kwargs["functions"] = FUNCTIONS_SCHEMA
+
         response_iter = cast(
             Any,
             openai.ChatCompletion.create(
@@ -28,28 +67,59 @@ def complete(
         if stream:
             for response in response_iter:
                 next_choice = response["choices"][0]
-                if (
-                    next_choice["finish_reason"] is None
-                    and "content" in next_choice["delta"]
-                ):
-                    yield next_choice["delta"]["content"]
+                yield next_choice
         else:
             next_choice = response_iter["choices"][0]
-            yield next_choice["message"]["content"]
+            next_choice["delta"] = next_choice["message"]
+            yield next_choice
 
 
-def num_tokens_from_messages_openai(messages: List[Message], model: str) -> int:
-    encoding = tiktoken.encoding_for_model(model)
+def num_tokens_from_messages_openai(messages, model="gpt-3.5-turbo-0613"):
+    """Return the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = (
+            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        )
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        return num_tokens_from_messages_openai(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        return num_tokens_from_messages_openai(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
     num_tokens = 0
     for message in messages:
-        # every message follows {role/name}\n{content}\n
-        num_tokens += 4
+        num_tokens += tokens_per_message
         for key, value in message.items():
-            assert isinstance(value, str)
+            logging.debug(f"key: {key}, value: {value}")
+            if key == "function_call":
+                # TODO: is this correct?
+                value = f"{value['name']}({value['arguments']})"
+            if key == "content":
+                # TODO: content is None for some messages with function calls
+                if value is None:
+                    continue
             num_tokens += len(encoding.encode(value))
-            if key == "name":  # if there's a name, the role is omitted
-                num_tokens += -1  # role is always required and always 1 token
-    num_tokens += 2  # every reply is primed with assistant
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
     return num_tokens
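With "functions" passed through, the OpenAI API can end a turn with finish_reason "function_call" instead of plain content; the provider now yields each choice unmodified and leaves interpretation to the session layer. For orientation, an invented example of the merged assistant message that results:

    # Illustrative only: the shape of an assistant message requesting a function call.
    {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "python", "arguments": '{"source": "print(1 + 1)"}'},
    }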
diff --git a/gptcli/session.py b/gptcli/session.py
index 2ec8394..5b9324d 100644
--- a/gptcli/session.py
+++ b/gptcli/session.py
@@ -1,16 +1,21 @@
 from abc import abstractmethod
+import json
+import traceback
 from typing_extensions import TypeGuard
 from gptcli.assistant import Assistant
-from gptcli.completion import Message, ModelOverrides
+from gptcli.completion import FunctionCall, Message, ModelOverrides, merge_dicts
 from openai import InvalidRequestError, OpenAIError
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 
 class ResponseStreamer:
     def __enter__(self) -> "ResponseStreamer":
         return self
 
-    def on_next_token(self, token: str):
+    def on_message_delta(self, message_delta: Message):
+        pass
+
+    def on_function_result(self, result: dict):
         pass
 
     def __exit__(self, *args):
@@ -36,6 +41,9 @@ def response_streamer(self) -> ResponseStreamer:
     def on_chat_message(self, message: Message):
         pass
 
+    def on_function_call(self, function_name: str, **kwargs) -> Optional[dict]:
+        pass
+
     def on_chat_response(
         self, messages: List[Message], response: Message, overrides: ModelOverrides
     ):
@@ -87,35 +95,102 @@ def _rerun(self):
         _, args = self.user_prompts[-1]
         self._respond(args)
 
-    def _respond(self, args: ModelOverrides) -> bool:
-        """
-        Respond to the user's input and return whether the assistant's response was saved.
-        """
-        next_response: str = ""
-        try:
-            completion_iter = self.assistant.complete_chat(
-                self.messages, override_params=args
-            )
+    def _completion(self, args: ModelOverrides):
+        next_message: Message = {
+            "role": "",
+        }
+        finish_reason = None
+
+        completion_iter = self.assistant.complete_chat(
+            self.messages, override_params=args
+        )
+        try:
             with self.listener.response_streamer() as stream:
-                for response in completion_iter:
-                    next_response += response
-                    stream.on_next_token(response)
+                for completion in completion_iter:
+                    next_message = merge_dicts(next_message, completion["delta"])
+                    stream.on_message_delta(completion["delta"])
+                    finish_reason = completion["finish_reason"]
         except KeyboardInterrupt:
             # If the user interrupts the chat completion, we'll just return what we have so far
             pass
-        except InvalidRequestError as e:
-            self.listener.on_error(e)
-            return False
-        except OpenAIError as e:
-            self.listener.on_error(e)
-            return True
 
-        next_message: Message = {"role": "assistant", "content": next_response}
-        self.listener.on_chat_message(next_message)
-        self.listener.on_chat_response(self.messages, next_message, args)
+        if "content" not in next_message:
+            next_message["content"] = None
+
+        return {
+            "message": next_message,
+            "finish_reason": finish_reason,
+        }
+
+    def _handle_function_call(self, function_call: FunctionCall) -> Message:
+        function_name = function_call.get("name", "null")
+
+        function_result = None
+
+        try:
+            arguments = function_call.get("arguments", "{}")
+            if arguments.startswith("{"):
+                function_arguments = json.loads(arguments)
+            else:
+                # HACK: gpt-3.5-turbo sometimes returns a string instead of a dict for python calls
+                function_arguments = {
+                    "source": function_call.get("arguments", ""),
+                }
+
+            function_result = self.listener.on_function_call(
+                function_name, **function_arguments
+            )
+        except Exception:
+            function_result = {
+                "text/plain": f"Exception occurred:\n\n```{traceback.format_exc()}```"
+            }
+
+        content = ""
+        if function_result:
+            with self.listener.response_streamer() as stream:
+                stream.on_function_result(function_result)
+            content = function_result.get("text/plain")
+
+        return {
+            "role": "function",
+            "name": function_name,
+            "content": content,
+        }
+
+    def _respond(self, args: ModelOverrides) -> bool:
+        """
+        Respond to the user's input and return whether the assistant's response was saved.
+ """ + finish_reason: Optional[str] = None + + while finish_reason != "stop": + try: + completion = self._completion(args) + next_message = completion["message"] + finish_reason = completion["finish_reason"] + + if finish_reason is None: + # If the user interrupts the chat completion, we'll stop here + break + + except InvalidRequestError as e: + self.listener.on_error(e) + return False + except OpenAIError as e: + self.listener.on_error(e) + return True + + self.messages = self.messages + [next_message] + self.listener.on_chat_message(next_message) + self.listener.on_chat_response(self.messages, next_message, args) + if finish_reason == "function_call": + function_message = self._handle_function_call( + next_message["function_call"] + ) + self.messages = self.messages + [function_message] + self.listener.on_chat_message(function_message) - self.messages = self.messages + [next_message] return True def _validate_args(self, args: Dict[str, Any]) -> TypeGuard[ModelOverrides]: diff --git a/gptcli/shell.py b/gptcli/shell.py index 83bc681..a312054 100644 --- a/gptcli/shell.py +++ b/gptcli/shell.py @@ -14,8 +14,9 @@ def simple_response(assistant: Assistant, prompt: str, stream: bool) -> None: result = "" try: for response in response_iter: - result += response - sys.stdout.write(response) + delta = response["delta"].get("content") or "" + result += delta + sys.stdout.write(delta) except KeyboardInterrupt: pass finally: @@ -28,7 +29,7 @@ def execute(assistant: Assistant, prompt: str) -> None: messages.append({"role": "user", "content": prompt}) logging.info("User: %s", prompt) response_iter = assistant.complete_chat(messages, stream=False) - result = next(response_iter) + result = next(response_iter)["delta"].get("content") or "" logging.info("Assistant: %s", result) with tempfile.NamedTemporaryFile(mode="w", prefix="gptcli-", delete=False) as f: @@ -51,5 +52,5 @@ def execute(assistant: Assistant, prompt: str) -> None: shell = os.environ.get("SHELL", "/bin/bash") logging.info(f"Executing: {command}") - print(f"Executing:\n{command}") + print(f"Executing:\n{command}", file=sys.stderr) subprocess.run([shell, f.name]) diff --git a/requirements.txt b/requirements.txt index a0744b4..a650e97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,8 @@ prompt-toolkit==3.0.38 pytest==7.3.1 PyYAML==6.0 rich==13.3.2 -tiktoken==0.3.3 +tiktoken==0.4.0 tokenizers==0.13.3 typing_extensions==4.5.0 +jupyter-client==8.2.0 +imgcat==0.5.0 diff --git a/requirements_docker.txt b/requirements_docker.txt new file mode 100644 index 0000000..50e1ac4 --- /dev/null +++ b/requirements_docker.txt @@ -0,0 +1,7 @@ +ipython==8.14.0 +ipykernel==6.23.2 +numpy==1.25.0 +matplotlib==3.7.1 +scipy==1.10.1 +scikit-learn==1.2.2 +pandas==2.0.2