Skip to content

Commit

Permalink
start adding typing
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Mar 19, 2023
1 parent bbb9dfe commit c0dcf2c
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ jobs:
pip install nox-poetry
- name: Run lint
run: nox -s lint
- name: Type checking
run: |
poetry run mypy genai/
39 changes: 11 additions & 28 deletions genai/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Creates user and system messages as context for ChatGPT, using the history of the current IPython session.
"""
from typing import Any, Dict, List

from IPython.core.interactiveshell import InteractiveShell

try:
import pandas as pd
Expand All @@ -12,43 +12,33 @@
from . import tokens


def craft_message(text, role="user"):
def craft_message(text: str, role: str = "user") -> Dict[str, str]:
    """Wrap *text* in a chat-message dict for the given *role* (default "user")."""
    return {"role": role, "content": text}


def craft_user_message(code):
def craft_user_message(code: str) -> Dict[str, str]:
    """Build a "user"-role chat message carrying the given cell source *code*."""
    return {"content": code, "role": "user"}


def repr_genai_pandas(output):
def repr_genai_pandas(output: Any) -> str:
    """Render *output* for inclusion in a ChatGPT prompt.

    DataFrames and Series are down-sampled to at most the configured pandas
    display limits and rendered as Markdown tables; anything else falls back
    to ``repr``.
    """
    if not PANDAS_INSTALLED:
        return repr(output)

    if isinstance(output, pd.DataFrame):
        # to_markdown() does not honor the display.max_rows/max_columns
        # options, so we truncate the frame ourselves. Either option may be
        # None (pandas' documented "unlimited" value), and min(None, n)
        # raises TypeError — keep the full axis in that case.
        max_columns = pd.options.display.max_columns
        max_rows = pd.options.display.max_rows
        num_columns = output.shape[1] if max_columns is None else min(max_columns, output.shape[1])
        num_rows = output.shape[0] if max_rows is None else min(max_rows, output.shape[0])

        # NOTE(review): sample() picks rows/columns at random (unseeded), so
        # the rendered subset is non-deterministic — presumably intentional
        # to give the model a spread of the data; confirm.
        sampled = output.sample(num_columns, axis=1).sample(num_rows, axis=0)

        return sampled.to_markdown()

    if isinstance(output, pd.Series):
        # Same row truncation for Series, with the same None ("unlimited") guard.
        max_rows = pd.options.display.max_rows
        num_rows = output.shape[0] if max_rows is None else min(max_rows, output.shape[0])
        sampled = output.sample(num_rows)
        return sampled.to_markdown()

    return repr(output)


def repr_genai(output):
'''Compute a GPT-3.5 friendly representation of the output of a cell.
For DataFrames and Series this means Markdown.
'''
def repr_genai(output: Any) -> str:
if not PANDAS_INSTALLED:
return repr(output)

Expand All @@ -58,12 +48,10 @@ def repr_genai(output):
return repr_genai_pandas(output)


def craft_output_message(output):
"""Craft a message from the output of a cell."""
def craft_output_message(output: Any) -> Dict[str, str]:
    """Render a cell's *output* and wrap it as a "system"-role chat message."""
    rendered = repr_genai(output)
    return craft_message(rendered, "system")


# tokens to identify which cells to ignore based on the first line
ignore_tokens = [
"# genai:ignore",
"#ignore",
Expand All @@ -77,12 +65,10 @@ def craft_output_message(output):
]


def get_historical_context(ipython, num_messages=5, model="gpt-3.5-turbo-0301"):
"""Create a series of messages to use as context for ChatGPT."""
def get_historical_context(
ipython: 'InteractiveShell', num_messages: int = 5, model: str = "gpt-3.5-turbo-0301"
) -> List[Dict[str, str]]:
raw_inputs = ipython.history_manager.input_hist_raw

# Now filter out any inputs that start with our filters
# This has to keep the input index as the key for the output
inputs = {}
for i, input in enumerate(raw_inputs):
if input is None or input.strip() == "":
Expand All @@ -92,17 +78,14 @@ def get_historical_context(ipython, num_messages=5, model="gpt-3.5-turbo-0301"):
inputs[i] = input

outputs = ipython.history_manager.output_hist

indices = sorted(inputs.keys())
context = []

# We will use the last `num_messages` inputs and outputs to establish context
for index in indices[-num_messages:]:
context.append(craft_user_message(inputs[index]))

if index in outputs:
context.append(craft_output_message(outputs[index]))

context = tokens.trim_messages_to_fit_token_limit(context, model=model)

return context
70 changes: 46 additions & 24 deletions genai/generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, Iterator, List, TypedDict

import openai

NOTEBOOK_CREATE_NEXT_CELL_PROCLAMATION = """
Expand All @@ -16,33 +18,61 @@
""".strip() # noqa: E501


def content(completion):
class Completion(TypedDict):
    """Shape of a non-streaming ChatCompletion response (choices only)."""

    choices: List[Dict[str, Any]]


def content(completion: Completion):
    """Return the message text from the first choice of *completion*."""
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]


def deltas(completion):
class Delta(TypedDict):
    """Incremental message fragment within one streamed choice."""

    content: str


class StreamChoice(TypedDict):
    """Single choice entry of a streamed completion chunk."""

    delta: Delta


class StreamCompletion(TypedDict):
    """Shape of one chunk of a streaming ChatCompletion response."""

    choices: List[StreamChoice]


def deltas(completion: Iterator[StreamCompletion]) -> Iterator[str]:
    """Yield the text fragments of a streaming completion, in arrival order."""
    for chunk in completion:
        delta = chunk["choices"][0]["delta"]
        if "content" not in delta:
            # Some chunks (e.g. the role announcement) carry no text; skip them.
            continue
        yield delta["content"]


def generate_next_cell(
context, # List[Dict[str, str]]
text,
stream=False,
):
context: List[Dict[str, str]],
text: str,
stream: bool = False,
) -> Iterator[str]:
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
# Establish the context in which GPT will respond
{
"role": "system",
"content": NOTEBOOK_CREATE_NEXT_CELL_PROCLAMATION,
},
# In, Out
*context,
# The user code/text
{
"role": "user",
"content": text,
Expand All @@ -58,15 +88,12 @@ def generate_next_cell(


def generate_exception_suggestion(
# The user's code
code,
# The exception with traceback
etype,
evalue,
plaintext_traceback,
stream=False,
):
# Cap our error report at ~1024 characters
code: str,
etype: type,
evalue: BaseException,
plaintext_traceback: str,
stream: bool = False,
) -> Iterator[str]:
error_report = f"{etype.__name__}: {evalue}\n{plaintext_traceback}"

if len(error_report) > 1024:
Expand All @@ -75,21 +102,16 @@ def generate_exception_suggestion(
messages = []

messages.append(
# Establish the context in which GPT will respond with role: assistant
{
"role": "system",
"content": NOTEBOOK_ERROR_DIAGNOSER_PROCLAMATION,
},
)

if code is not None:
messages.append(
# The user sent code
{"role": "user", "content": code}
)
messages.append({"role": "user", "content": code})

messages.append(
# The system wrote back with the error
{
"role": "system",
"content": error_report,
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy]
ignore_missing_imports = True
warn_unused_configs = True
49 changes: 48 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ bump2version = "^1.0.1"

[tool.poetry.group.dev.dependencies]
pandas = "^1.5.3"
mypy = "^1.1.1"

[build-system]
requires = ["poetry-core>=1.0.1"]
Expand Down

0 comments on commit c0dcf2c

Please sign in to comment.