Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAI code interpreter (draft) #37

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
.vscode/

gpt.yml
23 changes: 23 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Sandboxed image for gpt-cli with code execution enabled.
FROM python:3.10-bullseye
WORKDIR /app

# Install base requirements in their own layer so source edits don't bust the cache.
COPY requirements.txt ./
RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements.txt

# Extra packages only needed inside the container (interpreter sandbox deps).
COPY requirements_docker.txt ./
RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements_docker.txt

# Copy the application source last (most frequently changing layer).
COPY . .

# Working directory for files produced by executed code.
RUN mkdir -p /mnt/output

# Drop root: code execution runs as the unprivileged "gpt" user from here on.
RUN adduser --disabled-password gpt
USER gpt
# NOTE(review): relies on Docker setting $HOME for the "gpt" user in RUN steps — confirm,
# or use the explicit /home/gpt path.
RUN mkdir -p $HOME/.config/gpt-cli
# NOTE(review): gpt.yml is listed in .gitignore; it must still be present in the
# Docker build context (and not excluded by .dockerignore) for this step to work.
RUN cp /app/gpt.yml $HOME/.config/gpt-cli/gpt.yml


WORKDIR /mnt/output

# Opt-in flag checked by gpt.py/assistant.py before enabling the code interpreter.
ENV GPTCLI_ALLOW_CODE_EXECUTION=1
ENTRYPOINT ["python", "/app/gpt.py"]
19 changes: 17 additions & 2 deletions gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
choose_config_file,
read_yaml_config,
)
from gptcli.interpreter import CodeInterpreterListener
from gptcli.llama import init_llama_models
from gptcli.logging import LoggingChatListener
from gptcli.cost import PriceChatListener
Expand Down Expand Up @@ -89,6 +90,12 @@ def parse_args(config: GptCliConfig):
default=config.log_file,
help="The file to write logs to",
)
parser.add_argument(
"--history_file",
type=str,
default=config.history_file,
help="The file to write chat history to",
)
parser.add_argument(
"--log_level",
type=str,
Expand Down Expand Up @@ -215,16 +222,24 @@ def __init__(self, assistant: Assistant, markdown: bool, show_price: bool):
if show_price:
listeners.append(PriceChatListener(assistant))

if os.environ.get("GPTCLI_ALLOW_CODE_EXECUTION") == "1":
listeners.append(CodeInterpreterListener("python"))

listener = CompositeChatListener(listeners)
super().__init__(assistant, listener)
super().__init__(
assistant,
listener,
)


def run_interactive(args, assistant):
logger.info("Starting a new chat session. Assistant config: %s", assistant.config)
session = CLIChatSession(
assistant=assistant, markdown=args.markdown, show_price=args.show_price
)
history_filename = os.path.expanduser("~/.config/gpt-cli/history")
history_filename = args.history_file or os.path.expanduser(
"~/.config/gpt-cli/history"
)
os.makedirs(os.path.dirname(history_filename), exist_ok=True)
input_provider = CLIUserInputProvider(history_filename=history_filename)
session.loop(input_provider)
Expand Down
12 changes: 12 additions & 0 deletions gpt.yml.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Template for the gpt-cli config baked into the Docker image (copied to
# ~/.config/gpt-cli/gpt.yml — see Dockerfile). Fill in the API key before building.
markdown: True
# NOTE(review): placeholder — replace with a real key; do not commit the filled-in file.
openai_api_key: <YOUR_OPENAI_API_KEY_HERE>
# Logs and history go under /mnt so they can be bind-mounted out of the container.
log_file: /mnt/gpt.log
log_level: DEBUG
history_file: /mnt/history
assistants:
python:
model: gpt-4-0613
# Enables the code interpreter for this assistant (still gated by
# GPTCLI_ALLOW_CODE_EXECUTION=1 at runtime).
enable_code_execution: True
messages:
- { role: "system", content: "You are a helpful assistant. You have access to a Python environment. You can install missing packages. You have access to the internet. The user can see the code you are executing and its output: do not repeat them to the user verbatim. Pre-installed packages: numpy, matplotlib, scipy, scikit-learn, pandas, ipython, ipykernel." }

46 changes: 34 additions & 12 deletions gptcli/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import logging
import os
from typing import Iterator, List
import anthropic

from gptcli.completion import CompletionProvider, Message
from gptcli.completion import (
Completion,
CompletionProvider,
Message,
make_completion,
make_completion_iter,
)

api_key = os.environ.get("ANTHROPIC_API_KEY")

Expand All @@ -25,16 +32,26 @@ def role_to_name(role: str) -> str:

def make_prompt(messages: List[Message]) -> str:
prompt = "\n".join(
[f"{role_to_name(message['role'])}{message['content']}" for message in messages]
[
f"{role_to_name(message['role'])}{message.get('content', '')}"
for message in messages
]
)
prompt += f"{role_to_name('assistant')}"
return prompt


class AnthropicCompletionProvider(CompletionProvider):
def complete(
self, messages: List[Message], args: dict, stream: bool = False
) -> Iterator[str]:
self,
messages: List[Message],
args: dict,
stream: bool = False,
enable_code_execution: bool = False,
) -> Iterator[Completion]:
if enable_code_execution:
raise ValueError("Code execution is not supported by Anthropic models")

kwargs = {
"prompt": make_prompt(messages),
"stop_sequences": [anthropic.HUMAN_PROMPT],
Expand All @@ -49,14 +66,19 @@ def complete(
client = get_client()
if stream:
response = client.completion_stream(**kwargs)
else:
response = [client.completion(**kwargs)]

prev_completion = ""
for data in response:
next_completion = data["completion"]
yield next_completion[len(prev_completion) :]
prev_completion = next_completion
def content_iter() -> Iterator[str]:
prev_completion = ""
for data in response:
next_completion = data["completion"]
yield next_completion[len(prev_completion) :]
prev_completion = next_completion

for x in make_completion_iter(content_iter()):
yield x
else:
response = client.completion(**kwargs)
yield make_completion(response["completion"], finish_reason="stop")


def num_tokens_from_messages_anthropic(messages: List[Message], model: str) -> int:
Expand All @@ -65,4 +87,4 @@ def num_tokens_from_messages_anthropic(messages: List[Message], model: str) -> i


def num_tokens_from_completion_anthropic(message: Message, model: str) -> int:
return anthropic.count_tokens(message["content"])
return anthropic.count_tokens(message.get("content") or "")
15 changes: 12 additions & 3 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import platform
from typing import Any, Dict, Iterator, Optional, TypedDict, List

from gptcli.completion import CompletionProvider, ModelOverrides, Message
from gptcli.completion import Completion, CompletionProvider, ModelOverrides, Message
from gptcli.google import GoogleCompletionProvider
from gptcli.llama import LLaMACompletionProvider
from gptcli.openai import OpenAICompletionProvider
Expand All @@ -16,12 +16,14 @@ class AssistantConfig(TypedDict, total=False):
model: str
temperature: float
top_p: float
enable_code_execution: bool


CONFIG_DEFAULTS = {
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"top_p": 1.0,
"enable_code_execution": False,
}

DEFAULT_ASSISTANTS: Dict[str, AssistantConfig] = {
Expand Down Expand Up @@ -89,7 +91,7 @@ def init_messages(self) -> List[Message]:
return self.config.get("messages", [])[:]

def supported_overrides(self) -> List[str]:
return ["model", "temperature", "top_p"]
return ["model", "temperature", "top_p", "enable_code_execution"]

def _param(self, param: str, override_params: ModelOverrides) -> Any:
# If the param is in the override_params, use that value
Expand All @@ -101,9 +103,15 @@ def _param(self, param: str, override_params: ModelOverrides) -> Any:

def complete_chat(
self, messages, override_params: ModelOverrides = {}, stream: bool = True
) -> Iterator[str]:
) -> Iterator[Completion]:
model = self._param("model", override_params)
completion_provider = get_completion_provider(model)

enable_code_execution = (
bool(self._param("enable_code_execution", override_params))
and os.environ.get("GPTCLI_ALLOW_CODE_EXECUTION") == "1"
)

return completion_provider.complete(
messages,
{
Expand All @@ -112,6 +120,7 @@ def complete_chat(
"top_p": float(self._param("top_p", override_params)),
},
stream,
enable_code_execution,
)


Expand Down
78 changes: 68 additions & 10 deletions gptcli/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import base64
import logging
import re
import json
from imgcat import imgcat
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from openai import OpenAIError, InvalidRequestError
Expand All @@ -9,6 +13,7 @@
from typing import Any, Dict, Optional, Tuple

from rich.text import Text
from gptcli.completion import FunctionCall, Message, merge_dicts
from gptcli.session import (
ALL_COMMANDS,
COMMAND_CLEAR,
Expand All @@ -32,7 +37,7 @@
class StreamingMarkdownPrinter:
def __init__(self, console: Console, markdown: bool):
self.console = console
self.current_text = ""
self.current_message = {}
self.markdown = markdown
self.live: Optional[Live] = None

Expand All @@ -44,14 +49,55 @@ def __enter__(self) -> "StreamingMarkdownPrinter":
self.live.__enter__()
return self

def print(self, text: str):
self.current_text += text
def _format_function_call(self, function_call: FunctionCall) -> str:
text = ""
if function_call.get("name") == "python_eval":
source = function_call.get("arguments", "")
try:
source = json.loads(source).get("source", "")
except:
source = source + '"}'
try:
source = json.loads(source).get("source", "")
except:
source = ""

text += "\n\nExecuting Python code:\n"
text += f"```python\n{source}\n```"
else:
function_name = function_call.get("name", "?")
function_arguments = function_call.get("arguments", {})
text += f"""\n
Calling function:

```
{function_name}({function_arguments})
```"""
return text

def print(self, message_delta: Message):
# Fold the incremental delta into the accumulated message, then re-render.
# NOTE(review): merge_dicts is a project helper — presumably it concatenates
# string fields and recurses into nested dicts; confirm against its definition.
self.current_message = merge_dicts(self.current_message, message_delta)

if self.markdown:
# Markdown mode re-renders the ENTIRE accumulated message each delta via Live.
assert self.live
text = self.current_message.get("content", "")

# Append a rendered view of any (partial) function call accumulated so far.
function_call = self.current_message.get("function_call")
if function_call:
text += self._format_function_call(function_call)

content = Markdown(text, style="green")
self.live.update(content)
self.live.refresh()
else:
# Plain mode prints only the new delta, raw and unformatted.
text = message_delta.get("content") or ""
function_call = message_delta.get("function_call")
if function_call:
# Streaming deltas may carry only "name" or only "arguments".
if "name" in function_call:
text += function_call["name"]
if "arguments" in function_call:
text += function_call["arguments"]

self.console.print(Text(text, style="green"), end="")

def __exit__(self, *args):
Expand All @@ -66,17 +112,29 @@ def __init__(self, console: Console, markdown: bool):
self.console = console
self.markdown = markdown
self.printer = StreamingMarkdownPrinter(self.console, self.markdown)
self.first_token = True

def __enter__(self):
self.printer.__enter__()
return self

def on_next_token(self, token: str):
if self.first_token and token.startswith(" "):
token = token[1:]
self.first_token = False
self.printer.print(token)
def on_message_delta(self, message_delta: Message):
# Forward each streaming delta to the markdown printer, which accumulates
# and re-renders the message.
self.printer.print(message_delta)

def on_function_result(self, result: dict):
# Display the output of an executed function call.
# NOTE(review): `result` appears to be a Jupyter-style MIME bundle keyed by
# content type ("image/png" base64, "text/plain") — confirm against the
# interpreter listener that produces it.
self.console.print(Text("Function result:", style="yellow"))
if "image/png" in result:
# Inline-render the image in terminals that support imgcat.
image_base64 = result["image/png"]
image_bytes = base64.b64decode(image_base64)
imgcat(image_bytes)
if "text/plain" in result:
text = result["text/plain"]
if self.markdown:
content = Markdown(
f"```\n{text}\n```",
)
else:
content = Text(text, style="yellow")
self.console.print(content)

def __exit__(self, *args):
self.printer.__exit__(*args)
Expand Down
Loading
Loading