OpenAI code interpreter (draft) #37

Draft · wants to merge 9 commits into base: main

Changes from 1 commit:
fix non-openai models
kharvd committed Jun 24, 2023
commit 7dfd8b9a14b66756ac579ad962772fe7a11df21f
4 changes: 0 additions & 4 deletions gpt.py
@@ -1,11 +1,7 @@
 #!/usr/bin/env python
 import os
-import subprocess
-import tempfile
-import traceback
-from typing import cast
 import openai
 import random
 import argparse
 import sys
 import logging
46 changes: 34 additions & 12 deletions gptcli/anthropic.py
@@ -1,8 +1,15 @@
+import logging
 import os
 from typing import Iterator, List
 import anthropic
 
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import (
+    Completion,
+    CompletionProvider,
+    Message,
+    make_completion,
+    make_completion_iter,
+)
 
 api_key = os.environ.get("ANTHROPIC_API_KEY")
@@ -25,16 +32,26 @@ def role_to_name(role: str) -> str:
 
 def make_prompt(messages: List[Message]) -> str:
     prompt = "\n".join(
-        [f"{role_to_name(message['role'])}{message['content']}" for message in messages]
+        [
+            f"{role_to_name(message['role'])}{message.get('content', '')}"
+            for message in messages
+        ]
     )
     prompt += f"{role_to_name('assistant')}"
     return prompt
 
 
 class AnthropicCompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
+        if enable_code_execution:
+            raise ValueError("Code execution is not supported by Anthropic models")
+
         kwargs = {
             "prompt": make_prompt(messages),
             "stop_sequences": [anthropic.HUMAN_PROMPT],
@@ -49,14 +66,19 @@ def complete(
         client = get_client()
         if stream:
             response = client.completion_stream(**kwargs)
-        else:
-            response = [client.completion(**kwargs)]
 
-        prev_completion = ""
-        for data in response:
-            next_completion = data["completion"]
-            yield next_completion[len(prev_completion) :]
-            prev_completion = next_completion
+            def content_iter() -> Iterator[str]:
+                prev_completion = ""
+                for data in response:
+                    next_completion = data["completion"]
+                    yield next_completion[len(prev_completion) :]
+                    prev_completion = next_completion
+
+            for x in make_completion_iter(content_iter()):
+                yield x
+        else:
+            response = client.completion(**kwargs)
+            yield make_completion(response["completion"], finish_reason="stop")
 
 
 def num_tokens_from_messages_anthropic(messages: List[Message], model: str) -> int:
@@ -65,4 +87,4 @@ def num_tokens_from_completion_anthropic(message: Message, model: str) -> int:
 
 
 def num_tokens_from_completion_anthropic(message: Message, model: str) -> int:
-    return anthropic.count_tokens(message["content"])
+    return anthropic.count_tokens(message.get("content") or "")
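
Note on the streaming change above: Anthropic's completion_stream yields cumulative snapshots of the text so far rather than increments, which is why content_iter slices off the previously seen prefix before yielding. A minimal standalone sketch of that pattern (the snapshot values here are made up):

```python
from typing import Iterator


def cumulative_to_deltas(snapshots: Iterator[str]) -> Iterator[str]:
    """Convert a stream of cumulative completion snapshots into deltas."""
    prev = ""
    for snapshot in snapshots:
        # Each snapshot repeats everything seen so far; keep only the new tail.
        yield snapshot[len(prev) :]
        prev = snapshot


# Hypothetical snapshots, as completion_stream might produce them:
print(list(cumulative_to_deltas(iter(["He", "Hello", "Hello, wor", "Hello, world"]))))
# ['He', 'llo', ', wor', 'ld']
```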
2 changes: 1 addition & 1 deletion gptcli/cli.py
@@ -91,7 +91,7 @@ def print(self, message_delta: Message):
             self.live.refresh()
         else:
             self.console.print(
-                Text(message_delta.get("content", ""), style="green"), end=""
+                Text(message_delta.get("content") or "", style="green"), end=""
             )
 
     def __exit__(self, *args):
30 changes: 29 additions & 1 deletion gptcli/completion.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+import logging
 from typing import Iterator, List, Optional, TypedDict
 from typing_extensions import Required
 
@@ -15,7 +16,7 @@ class Message(TypedDict, total=False):
     function_call: FunctionCall
 
 
-def merge_dicts(a: dict, b: dict):
+def merge_dicts(a, b):
    """
    Given two nested dicts with string values, merge dict `b` into dict `a`, concatenating
    string values.
@@ -45,6 +46,33 @@ class Completion(TypedDict):
     finish_reason: Optional[str]
 
 
+def make_completion(
+    content_delta: str,
+    role: str = "assistant",
+    finish_reason: Optional[str] = None,
+) -> Completion:
+    delta: Message = {
+        "role": role,
+        "content": content_delta,
+    }
+    return {
+        "delta": delta,
+        "finish_reason": finish_reason,
+    }
+
+
+def make_completion_iter(
+    content_iter: Iterator[str],
+    role: str = "assistant",
+    finish_reason: Optional[str] = "stop",
+) -> Iterator[Completion]:
+    logging.debug("make_completion_iter")
+    yield make_completion("", role=role)
+    for content in content_iter:
+        yield make_completion(content, role="")
+    yield make_completion("", role="", finish_reason=finish_reason)
+
+
 class CompletionProvider:
     @abstractmethod
     def complete(
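
For context on how providers use these helpers: make_completion_iter adapts a plain stream of content chunks into the Completion dicts the session layer consumes, namely an opening delta that sets the role, one delta per chunk, and a closing empty delta carrying the finish reason. A rough sketch of the output, assuming the definitions above:

```python
from gptcli.completion import make_completion_iter

# Hypothetical content chunks, e.g. from a provider's streaming API.
chunks = iter(["Hel", "lo"])

for completion in make_completion_iter(chunks):
    print(completion)

# Expected output, per the implementation above:
# {'delta': {'role': 'assistant', 'content': ''}, 'finish_reason': None}
# {'delta': {'role': '', 'content': 'Hel'}, 'finish_reason': None}
# {'delta': {'role': '', 'content': 'lo'}, 'finish_reason': None}
# {'delta': {'role': '', 'content': ''}, 'finish_reason': 'stop'}
```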
24 changes: 18 additions & 6 deletions gptcli/google.py
@@ -1,6 +1,6 @@
 from typing import Iterator, List
 import google.generativeai as genai
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import Completion, CompletionProvider, Message, make_completion
 
 
@@ -14,11 +14,16 @@ def role_to_author(role: str) -> str:
 
 def make_prompt(messages: List[Message]):
     system_messages = [
-        message["content"] for message in messages if message["role"] == "system"
+        message.get("content") or ""
+        for message in messages
+        if message["role"] == "system"
     ]
     context = "\n".join(system_messages)
     prompt = [
-        {"author": role_to_author(message["role"]), "content": message["content"]}
+        {
+            "author": role_to_author(message["role"]),
+            "content": message.get("content", ""),
+        }
         for message in messages
         if message["role"] != "system"
     ]
@@ -27,8 +32,15 @@ def make_prompt(messages: List[Message]):
 
 class GoogleCompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
+        if enable_code_execution:
+            raise ValueError("Code execution is not supported by Google models")
+
         context, prompt = make_prompt(messages)
         kwargs = {
             "context": context,
@@ -40,4 +52,4 @@ def complete(
             kwargs["top_p"] = args["top_p"]
 
         response = genai.chat(**kwargs)
-        yield response.last
+        yield make_completion(response.last, finish_reason="stop")
39 changes: 30 additions & 9 deletions gptcli/llama.py
@@ -1,9 +1,16 @@
+import logging
 import os
 import sys
 from typing import Iterator, List, Optional, TypedDict, cast
-from llama_cpp import Completion, CompletionChunk, Llama
+from llama_cpp import Completion as LlamaCompletion, CompletionChunk, Llama
 
-from gptcli.completion import CompletionProvider, Message
+from gptcli.completion import (
+    CompletionProvider,
+    Message,
+    Completion,
+    make_completion,
+    make_completion_iter,
+)
 
 
 class LLaMAModelConfig(TypedDict):
@@ -40,7 +47,7 @@ def role_to_name(role: str, model_config: LLaMAModelConfig) -> str:
 def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str:
     prompt = "\n".join(
         [
-            f"{role_to_name(message['role'], model_config)} {message['content']}"
+            f"{role_to_name(message['role'], model_config)} {message.get('content', '')}"
             for message in messages
         ]
     )
@@ -50,10 +57,17 @@ def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str:
 
 class LLaMACompletionProvider(CompletionProvider):
     def complete(
-        self, messages: List[Message], args: dict, stream: bool = False
-    ) -> Iterator[str]:
+        self,
+        messages: List[Message],
+        args: dict,
+        stream: bool = False,
+        enable_code_execution: bool = False,
+    ) -> Iterator[Completion]:
         assert LLAMA_MODELS, "LLaMA models not initialized"
 
+        if enable_code_execution:
+            raise ValueError("Code execution is not supported by LLaMA models")
+
         model_config = LLAMA_MODELS[args["model"]]
 
         with suppress_stderr():
@@ -64,7 +78,6 @@ def complete(
                 use_mlock=True,
             )
         prompt = make_prompt(messages, model_config)
-        print(prompt)
 
         extra_args = {}
         if "temperature" in args:
@@ -80,11 +93,19 @@ def complete(
             echo=False,
             **extra_args,
         )
+
         if stream:
-            for x in cast(Iterator[CompletionChunk], gen):
-                yield x["choices"][0]["text"]
+
+            def completion_iter() -> Iterator[str]:
+                for data in cast(Iterator[CompletionChunk], gen):
+                    yield data["choices"][0]["text"]
+
+            for x in make_completion_iter(completion_iter()):
+                yield x
         else:
-            yield cast(Completion, gen)["choices"][0]["text"]
+            yield make_completion(
+                cast(LlamaCompletion, gen)["choices"][0]["text"], finish_reason="stop"
+            )
 
 
 # https://stackoverflow.com/a/50438156
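
The suppress_stderr helper used above (credited to the linked Stack Overflow answer) silences llama.cpp's model-loading chatter; its body is outside this hunk. A minimal sketch of such a context manager, assuming it follows the linked answer by redirecting the OS-level stderr file descriptor:

```python
import os
import sys
from contextlib import contextmanager


@contextmanager
def suppress_stderr():
    """Temporarily point fd 2 at /dev/null, silencing C-level output too."""
    stderr_fd = sys.stderr.fileno()
    saved_fd = os.dup(stderr_fd)
    try:
        with open(os.devnull, "w") as devnull:
            os.dup2(devnull.fileno(), stderr_fd)
        yield
    finally:
        os.dup2(saved_fd, stderr_fd)  # restore the original stderr
        os.close(saved_fd)
```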
2 changes: 2 additions & 0 deletions gptcli/openai.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Iterator, List, cast
 import openai
 import tiktoken
@@ -69,6 +70,7 @@ def complete(
             yield next_choice
         else:
             next_choice = response_iter["choices"][0]
+            next_choice["delta"] = next_choice["message"]
             yield next_choice
 
 
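
The added line papers over an asymmetry in the pre-1.0 openai response shape: streamed chunks expose their text under "delta", while non-streaming responses expose it under "message". Copying "message" into "delta" lets downstream code read choice["delta"] in both modes. A small sketch with a made-up response dict:

```python
# Hypothetical non-streaming chat response in the old openai dict format.
response = {
    "choices": [
        {
            "message": {"role": "assistant", "content": "Hi there"},
            "finish_reason": "stop",
        }
    ]
}

choice = response["choices"][0]
# Mirror the full message under "delta" so the session layer can treat
# streaming and non-streaming choices uniformly.
choice["delta"] = choice["message"]
assert choice["delta"]["content"] == "Hi there"
```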
7 changes: 1 addition & 6 deletions gptcli/session.py
@@ -1,8 +1,6 @@
 from abc import abstractmethod
 import logging
 import json
-import traceback
-import base64
 from typing_extensions import TypeGuard
 from gptcli.assistant import Assistant
 from gptcli.completion import FunctionCall, Message, ModelOverrides, merge_dicts
@@ -98,7 +96,7 @@ def _rerun(self):
         self._respond(args)
 
     def _completion(self, args: ModelOverrides):
-        next_message = {
+        next_message: Message = {
            "role": "",
        }
        finish_reason = None
@@ -113,9 +111,6 @@ def _completion(self, args: ModelOverrides):
                next_message = merge_dicts(next_message, completion["delta"])
                stream.on_message_delta(completion["delta"])
                finish_reason = completion["finish_reason"]
-
-            if next_message.get("function_call") is not None:
-                logging.debug(f"Function call: {next_message['function_call']}")
        except KeyboardInterrupt:
            # If the user interrupts the chat completion, we'll just return what we have so far
            pass
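
For reference, the loop above folds each streamed delta into next_message via merge_dicts, which (per its docstring in gptcli/completion.py) concatenates string values key by key. A small sketch of that accumulation, assuming merge_dicts behaves as documented:

```python
from gptcli.completion import Message, merge_dicts

next_message: Message = {"role": ""}

# Hypothetical deltas, as a provider might stream them.
deltas = [
    {"role": "assistant", "content": "Hel"},
    {"role": "", "content": "lo"},
]

for delta in deltas:
    # String values concatenate, so "content" grows chunk by chunk.
    next_message = merge_dicts(next_message, delta)

print(next_message)  # {'role': 'assistant', 'content': 'Hello'}
```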
9 changes: 5 additions & 4 deletions gptcli/shell.py
@@ -14,8 +14,9 @@ def simple_response(assistant: Assistant, prompt: str, stream: bool) -> None:
     result = ""
     try:
         for response in response_iter:
-            result += response
-            sys.stdout.write(response)
+            delta = response["delta"].get("content") or ""
+            result += delta
+            sys.stdout.write(delta)
     except KeyboardInterrupt:
         pass
     finally:
@@ -28,7 +29,7 @@ def execute(assistant: Assistant, prompt: str) -> None:
     messages.append({"role": "user", "content": prompt})
     logging.info("User: %s", prompt)
     response_iter = assistant.complete_chat(messages, stream=False)
-    result = next(response_iter)
+    result = next(response_iter)["delta"].get("content") or ""
     logging.info("Assistant: %s", result)
 
     with tempfile.NamedTemporaryFile(mode="w", prefix="gptcli-", delete=False) as f:
@@ -51,5 +52,5 @@ def execute(assistant: Assistant, prompt: str) -> None:
     shell = os.environ.get("SHELL", "/bin/bash")
 
     logging.info(f"Executing: {command}")
-    print(f"Executing:\n{command}")
+    print(f"Executing:\n{command}", file=sys.stderr)
     subprocess.run([shell, f.name])