Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

start adding typing #59

Merged
merged 1 commit into from
Mar 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ jobs:
pip install nox-poetry
- name: Run lint
run: nox -s lint
- name: Type checking
run: |
poetry run mypy genai/
39 changes: 11 additions & 28 deletions genai/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Creates user and system messages as context for ChatGPT, using the history of the current IPython session.
"""
from typing import Any, Dict, List

from IPython.core.interactiveshell import InteractiveShell

try:
import pandas as pd
Expand All @@ -12,43 +12,33 @@
from . import tokens


def craft_message(text, role="user"):
def craft_message(text: str, role: str = "user") -> Dict[str, str]:
    """Wrap *text* in an OpenAI chat-message dict with the given role."""
    return {"content": text, "role": role}


def craft_user_message(code):
def craft_user_message(code: str) -> Dict[str, str]:
    """Build a user-role chat message from cell source code."""
    return craft_message(code, role="user")


def repr_genai_pandas(output):
def repr_genai_pandas(output: Any) -> str:
    """Render a pandas object as Markdown for use as ChatGPT context.

    DataFrames and Series are sampled down to the pandas display limits
    before calling ``to_markdown``; anything else falls back to ``repr``.
    Note: ``sample`` picks rows/columns at random (not head/tail), so the
    "truncated" view is not deterministic.
    """
    if not PANDAS_INSTALLED:
        return repr(output)

    if isinstance(output, pd.DataFrame):
        # to_markdown() does not use the max_rows and max_columns options
        # so we have to truncate the dataframe ourselves.
        # The display options may be None (pandas' "no limit" value), and
        # min(None, n) raises TypeError — treat None as the full extent.
        max_columns = pd.options.display.max_columns
        max_rows = pd.options.display.max_rows
        num_columns = output.shape[1] if max_columns is None else min(max_columns, output.shape[1])
        num_rows = output.shape[0] if max_rows is None else min(max_rows, output.shape[0])

        sampled = output.sample(num_columns, axis=1).sample(num_rows, axis=0)

        return sampled.to_markdown()

    if isinstance(output, pd.Series):
        # Similar truncation for series, with the same None-means-unlimited guard
        max_rows = pd.options.display.max_rows
        num_rows = output.shape[0] if max_rows is None else min(max_rows, output.shape[0])
        sampled = output.sample(num_rows)
        return sampled.to_markdown()

    return repr(output)


def repr_genai(output):
'''Compute a GPT-3.5 friendly representation of the output of a cell.

For DataFrames and Series this means Markdown.
'''
def repr_genai(output: Any) -> str:
if not PANDAS_INSTALLED:
return repr(output)

Expand All @@ -58,12 +48,10 @@ def repr_genai(output):
return repr_genai_pandas(output)


def craft_output_message(output):
"""Craft a message from the output of a cell."""
def craft_output_message(output: Any) -> Dict[str, str]:
    """Build a system-role chat message describing a cell's output."""
    rendered = repr_genai(output)
    return craft_message(rendered, role="system")


# tokens to identify which cells to ignore based on the first line
ignore_tokens = [
"# genai:ignore",
"#ignore",
Expand All @@ -77,12 +65,10 @@ def craft_output_message(output):
]


def get_historical_context(ipython, num_messages=5, model="gpt-3.5-turbo-0301"):
"""Create a series of messages to use as context for ChatGPT."""
def get_historical_context(
ipython: 'InteractiveShell', num_messages: int = 5, model: str = "gpt-3.5-turbo-0301"
) -> List[Dict[str, str]]:
raw_inputs = ipython.history_manager.input_hist_raw

# Now filter out any inputs that start with our filters
# This has to keep the input index as the key for the output
inputs = {}
for i, input in enumerate(raw_inputs):
if input is None or input.strip() == "":
Expand All @@ -92,17 +78,14 @@ def get_historical_context(ipython, num_messages=5, model="gpt-3.5-turbo-0301"):
inputs[i] = input

outputs = ipython.history_manager.output_hist

indices = sorted(inputs.keys())
context = []

# We will use the last `num_messages` inputs and outputs to establish context
for index in indices[-num_messages:]:
context.append(craft_user_message(inputs[index]))

if index in outputs:
context.append(craft_output_message(outputs[index]))

context = tokens.trim_messages_to_fit_token_limit(context, model=model)

return context
70 changes: 46 additions & 24 deletions genai/generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, Iterator, List, TypedDict

import openai

NOTEBOOK_CREATE_NEXT_CELL_PROCLAMATION = """
Expand All @@ -16,33 +18,61 @@
""".strip() # noqa: E501


def content(completion):
# Shape of a non-streaming ChatCompletion response (only the part we read).
Completion = TypedDict(
    "Completion",
    {
        "choices": List[Dict[str, Any]],
    },
)


def content(completion: Completion) -> str:
    """Return the assistant message text from a non-streaming completion.

    Reads ``choices[0].message.content``; raises KeyError/IndexError if the
    response does not have that shape.
    """
    return completion["choices"][0]["message"]["content"]


def deltas(completion):
class Delta(TypedDict):
    # Incremental payload of one streamed chunk; "content" may be absent.
    content: str


class StreamChoice(TypedDict):
    # A single choice entry within a streamed chunk.
    delta: Delta


class StreamCompletion(TypedDict):
    # One chunk of a streaming ChatCompletion response.
    choices: List[StreamChoice]


def deltas(completion: Iterator[StreamCompletion]) -> Iterator[str]:
    """Yield each incremental text fragment from a streamed completion.

    Chunks whose delta carries no "content" key (e.g. role announcements)
    are skipped.
    """
    for chunk in completion:
        first_delta = chunk["choices"][0]["delta"]
        try:
            yield first_delta["content"]
        except KeyError:
            continue


def generate_next_cell(
context, # List[Dict[str, str]]
text,
stream=False,
):
context: List[Dict[str, str]],
text: str,
stream: bool = False,
) -> Iterator[str]:
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
# Establish the context in which GPT will respond
{
"role": "system",
"content": NOTEBOOK_CREATE_NEXT_CELL_PROCLAMATION,
},
# In, Out
*context,
# The user code/text
{
"role": "user",
"content": text,
Expand All @@ -58,15 +88,12 @@ def generate_next_cell(


def generate_exception_suggestion(
# The user's code
code,
# The exception with traceback
etype,
evalue,
plaintext_traceback,
stream=False,
):
# Cap our error report at ~1024 characters
code: str,
etype: type,
evalue: BaseException,
plaintext_traceback: str,
stream: bool = False,
) -> Iterator[str]:
error_report = f"{etype.__name__}: {evalue}\n{plaintext_traceback}"

if len(error_report) > 1024:
Expand All @@ -75,21 +102,16 @@ def generate_exception_suggestion(
messages = []

messages.append(
# Establish the context in which GPT will respond with role: assistant
{
"role": "system",
"content": NOTEBOOK_ERROR_DIAGNOSER_PROCLAMATION,
},
)

if code is not None:
messages.append(
# The user sent code
{"role": "user", "content": code}
)
messages.append({"role": "user", "content": code})

messages.append(
# The system wrote back with the error
{
"role": "system",
"content": error_report,
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy]
ignore_missing_imports = True
warn_unused_configs = True
49 changes: 48 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ bump2version = "^1.0.1"

[tool.poetry.group.dev.dependencies]
pandas = "^1.5.3"
mypy = "^1.1.1"

[build-system]
requires = ["poetry-core>=1.0.1"]
Expand Down