Skip to content

Commit

Permalink
inline generation in story textbox
Browse files Browse the repository at this point in the history
  • Loading branch information
socketteer committed Sep 19, 2021
1 parent 10acbaf commit d6809cd
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 109 deletions.
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ ooo what features! wow so cool

# Hotkeys


*Alt hotkeys correspond to Command on Mac*

### File

Expand All @@ -71,21 +71,24 @@ Visualization Settings: `Control-u`

Multimedia dialog: `u`

Tree Info: `i`, `Control-i`
Tree Info: `Control-i`

Node Metadata: `Control+Shift+N`

Run Code: `Control+Shift+B`


### Mode / display

Toggle edit / save edits: `e`, `Control-e`

Toggle story textbox editable: `Control-Shift-e`

Toggle visualize: `j`, `Control-j`

Toggle input box: `Tab`
Toggle bottom pane: `Tab`

Toggle debug box: `Control+Shift+D`
Toggle side pane: `Alt-p`

Toggle show children: `Alt-c`

Expand Down Expand Up @@ -143,6 +146,8 @@ Toggle archive node: `!`

Generate: `g`, `Control-g`

Inline generate: `Alt-i`

Add memory: `Control-m`

View current AI memory: `Control-Shift-m`
Expand Down
5 changes: 5 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@
- transformers
- playground

- fix vis

- store permanent alternative texts
- as ghostparents?

## Other TODO

### bugs
Expand Down
86 changes: 14 additions & 72 deletions components/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,15 +619,15 @@ def build(self, parent):
self.textbox_frame = ttk.Frame(self.pane)
self.pane.add(self.textbox_frame, weight=4)

self.textbox = SmartText(self.textbox_frame, bd=2, height=3, undo=True)
self.textbox = LoomTerminal(self.textbox_frame, bd=2, height=3, undo=True)
self.textbox.pack(side='top', fill='both', expand=True)
self.textboxes.append(self.textbox)
self.textbox.configure(**textbox_config(bg=edit_color()))
#self.textbox.tag_config("generated", font=('Georgia', self.state.preferences['font_size'], 'bold'))

self.textbox.bind("<Key>", self.key_pressed)
self.textbox.bind("<Alt-g>", self.inline_generate)
self.textbox.bind("<Command-g>", self.inline_generate)
self.textbox.bind("<Alt-i>", self.inline_generate)
self.textbox.bind("<Command-i>", self.inline_generate)
self.textbox.bind("<Alt-period>", lambda event: self.insert_inline_completion(step=1))
self.textbox.bind("<Alt-comma>", lambda event: self.insert_inline_completion(step=-1))
self.textbox.bind("<Command-period>", lambda event: self.insert_inline_completion(step=1))
Expand Down Expand Up @@ -658,23 +658,9 @@ def build(self, parent):
self.pane.add(self.completions_frame, weight=1)
self.completion_windows.body(self.completions_frame)

# def replace_selected_text(self, text):
# sel_first = self.textbox.index("sel.first")
# self.textbox.delete("sel.first", "sel.last")
# self.textbox.insert(sel_first, text)

def inline_generate(self, *args):
# get text up to cursor
if self.textbox.tag_ranges("sel"):
self.textbox.fix_selection()
prompt = self.textbox.get("1.0", "sel.first")
selected_range = self.textbox.selected_range()
else:
self.textbox.fix_insertion()
prompt = self.textbox.get("1.0", "insert")
selected_range = [len(prompt), len(prompt)]
settings = self.state.inline_generation_settings
threading.Thread(target=self.call_model_inline, args=(prompt, settings, selected_range)).start()
self.textbox.inline_generate(self.state.inline_generation_settings)

def generate(self, mode='completions', *args):
prompt = self.textbox.get("1.0", "end-1c")
Expand All @@ -690,76 +676,32 @@ def generate(self, mode='completions', *args):

def call_model(self, prompt, settings):
response, error = gen(prompt, settings)
# enable generate button
self.generate_button.configure(state='normal')
self.model_response = response
self.process_logprobs()
self.textbox.process_logprobs()
response_text_list = completions_text(response)
for completion in response_text_list:
self.completion_windows.open_window(completion)

def call_model_prompt(self, prompt, settings):
eval_settings = settings.copy()
eval_settings.update({'max_tokens': 1, 'num_continuations': 1, 'logprobs': 15})
response, error = gen(prompt, eval_settings)
# enable eval button
self.textbox.call_model_prompt(prompt, settings)
self.eval_prompt_button.configure(state='normal')
self.model_response = response
self.process_logprobs()


def call_model_inline(self, prompt, settings, selected_range):
response, error = gen(prompt, settings)
response_text_list = completions_text(response)
self.textbox.alternatives = []
inline_completions = [completion for completion in response_text_list if completion]
inline_completions.insert(0, self.textbox.get_range(selected_range[0], selected_range[1]))
self.completion_index = 0
completions_dict = {'alts': [{'text': completion} for completion in inline_completions],
'replace_range': [selected_range[0], selected_range[1]]}
self.textbox.alternatives.append(completions_dict)
self.inline_completions = completions_dict
#self.inserted_range = None
self.textbox.tag_remove("alternate", "1.0", tk.END)
self.insert_inline_completion()

self.textbox.call_model_inline(prompt, settings, selected_range)

def process_logprobs(self):
# TODO when prompt length is shorter
self.textbox.alternatives = []
if "tokens" in self.model_response['prompt']:
# check if prompt is shorter than text in textbox
prompt_length = len(self.model_response['prompt']['text'])
textbox_length = len(self.textbox.get("1.0", "end-1c"))
diff = textbox_length - prompt_length
for token_data in self.model_response['prompt']['tokens']:
#print(token_data)
if 'counterfactuals' in token_data and token_data['counterfactuals']:
alt_dict = {'alts': [],
'replace_range': [token_data['position']['start'] + diff, token_data['position']['end'] + diff],}
sorted_counterfactuals = {k: v for k, v in sorted(token_data['counterfactuals'].items(), key=lambda item: item[1], reverse=True)}
for token, prob in sorted_counterfactuals.items():
alt_dict['alts'].append({'text': token, 'logprob': prob, 'prob': logprobs_to_probs(prob)})
self.textbox.alternatives.append(alt_dict)
#print(self.textbox.alternatives)

def key_pressed(self, event):
if event.keysym == "Alt_L":
return
# else:
# self.accept_completion()

self.textbox.process_logprobs()

def insert_inline_completion(self, step=1, *args):
if self.inline_completions:
self.completion_index += step
completion = self.inline_completions['alts'][self.completion_index % len(self.inline_completions['alts'])]['text']
repl = self.inline_completions['replace_range']
self.textbox.replace_range(repl[0], repl[1], completion, "alternate")

self.textbox.insert_inline_completion(step)

def open_counterfactuals(self, event):
self.textbox.open_alt_dropdown(event)
return self.textbox.open_alt_dropdown(event)

def key_pressed(self, event):
if event.keysym == "Alt_L":
return

def export(self, *args):
pass
Expand Down
87 changes: 72 additions & 15 deletions components/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
import os
import codecs
from PIL import Image, ImageTk
from gpt import POSSIBLE_MODELS
from gpt import POSSIBLE_MODELS, gen, completions_text
from util.gpt_util import logprobs_to_probs
import json
import bisect
from util.util import split_indices
import threading


buttons = {'go': 'arrow-green',
'edit': 'edit-blue',
Expand Down Expand Up @@ -641,7 +644,7 @@ def key_pressed(self, event):



class SmartText(TextAware):
class LoomTerminal(TextAware):
"""
alternatives:
{
Expand All @@ -664,6 +667,8 @@ def __init__(self, *args, **kwargs):
self.tag_raise("sel")
self.tag_raise("insert")
self.alternatives = []
self.completion_index = None
self.inline_completions = None

def key_pressed(self, event):
pass
Expand Down Expand Up @@ -777,17 +782,9 @@ def alt_dropdown(self, alt_dict, show_probs=True):
self.tag_range("alternate", start_pos, end_pos)
# get x y coordinates of start_pos
start_index = self.index(f"1.0 + {start_pos} chars")
x, y = self.count("1.0", start_index, "xpixels", "ypixels")
# TODO adjust based on font size
x = x + self.winfo_rootx() + 5
y = y + self.winfo_rooty() + 45
# get Text scroll position
scroll_pos = self.yview()[0]
# get Text height
text_height = self.winfo_height()
scroll_offset = int(round(scroll_pos * text_height))

y = y - scroll_offset
self.update_idletasks()
bbox = self.bbox(start_index)
x, y = self.winfo_rootx() + bbox[0], self.winfo_rooty() + bbox[1] + 25

# create dropdown menu
menu = tk.Menu(self, tearoff=0)
Expand All @@ -805,7 +802,7 @@ def alt_dropdown(self, alt_dict, show_probs=True):
alt['text'],
tag="alternate"))

menu.add_separator()
#menu.add_separator()
# display current text

menu.tk_popup(x, y)
Expand Down Expand Up @@ -833,7 +830,67 @@ def open_alt_dropdown(self, event):
alt_dict = self.get_alt_dict(position)
if alt_dict:
self.alt_dropdown(alt_dict)

return True
return False

def inline_generate(self, generation_settings):
if self.tag_ranges("sel"):
self.fix_selection()
prompt = self.get("1.0", "sel.first")
selected_range = self.selected_range()
else:
self.fix_insertion()
prompt = self.get("1.0", "insert")
selected_range = [len(prompt), len(prompt)]
threading.Thread(target=self.call_model_inline, args=(prompt, generation_settings, selected_range)).start()

def call_model_inline(self, prompt, settings, selected_range):
response, error = gen(prompt, settings)
response_text_list = completions_text(response)
self.alternatives = []
inline_completions = [completion for completion in response_text_list if completion]
inline_completions.insert(0, self.get_range(selected_range[0], selected_range[1]))
self.completion_index = 0
completions_dict = {'alts': [{'text': completion} for completion in inline_completions],
'replace_range': [selected_range[0], selected_range[1]]}
self.alternatives.append(completions_dict)
self.inline_completions = completions_dict
#self.inserted_range = None
self.tag_remove("alternate", "1.0", tk.END)
self.insert_inline_completion()

def call_model_prompt(self, prompt, settings):
eval_settings = settings.copy()
eval_settings.update({'max_tokens': 1, 'num_continuations': 1, 'logprobs': 15})
response, error = gen(prompt, eval_settings)
# enable eval button
#self.eval_prompt_button.configure(state='normal')
self.model_response = response
self.process_logprobs()

def insert_inline_completion(self, step=1):
if self.inline_completions:
self.completion_index += step
completion = self.inline_completions['alts'][self.completion_index % len(self.inline_completions['alts'])]['text']
repl = self.inline_completions['replace_range']
self.replace_range(repl[0], repl[1], completion, "alternate")

def process_logprobs(self):
self.alternatives = []
if "tokens" in self.model_response['prompt']:
# check if prompt is shorter than text in textbox
prompt_length = len(self.model_response['prompt']['text'])
textbox_length = len(self.get("1.0", "end-1c"))
diff = textbox_length - prompt_length
for token_data in self.model_response['prompt']['tokens']:
#print(token_data)
if 'counterfactuals' in token_data and token_data['counterfactuals']:
alt_dict = {'alts': [],
'replace_range': [token_data['position']['start'] + diff, token_data['position']['end'] + diff],}
sorted_counterfactuals = {k: v for k, v in sorted(token_data['counterfactuals'].items(), key=lambda item: item[1], reverse=True)}
for token, prob in sorted_counterfactuals.items():
alt_dict['alts'].append({'text': token, 'logprob': prob, 'prob': logprobs_to_probs(prob)})
self.alternatives.append(alt_dict)

#################################
# Settings
Expand Down
14 changes: 2 additions & 12 deletions controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def refresh_textbox(self, **kwargs):
self.display.textbox.configure(font=Font(family="Georgia", size=self.state.preferences['font_size']),
spacing1=self.state.preferences['paragraph_spacing'],
spacing2=self.state.preferences['line_spacing'],
background=edit_color() if self.state.preferences["editable"] else bg_color())
background=edit_color() if self.state.preferences["editable"] or self.display.mode == "Edit" else bg_color())
#self.display.textbox.tag_config("node_select", font=Font(family="Georgia", size=self.state.preferences['font_size'], weight="bold"))

# Fill textbox with text history, disable editing
Expand Down Expand Up @@ -905,15 +905,6 @@ def write_textbox_changes(self):
old_text = self.state.ancestry_text(self.state.selected_node)
new_text = self.display.textbox.get("1.0", "end-1c")
if old_text != new_text:
# split_old_text = re.split(r'(\s+)', old_text)
# split_new_text = re.split(r'(\s+)', new_text)
# split_old_indices = [0]
# split_new_indices = [0]
# for i in range(len(split_old_text)):
# split_old_indices.append(split_old_indices[-1] + len(split_old_text[i]))
# for i in range(len(split_new_text)):
# split_new_indices.append(split_new_indices[-1] + len(split_new_text[i]))

dmp = diff_match_patch()
a = diff_linesToWords(old_text, new_text, delimiter=re.compile(' '))
diffs = dmp.diff_main(a[0], a[1], False)
Expand Down Expand Up @@ -1761,8 +1752,7 @@ def workspace_dialog(self):
if dialog.result:
self.refresh_workspace()


@metadata(name="Show Info", keys=["<i>", "<Control-i>"], display_key="i")
@metadata(name="Show Info", keys=["<Control-i>"], display_key="i")
def info_dialog(self):
all_text = "".join([d["text"] for d in self.state.tree_node_dict.values()])

Expand Down
2 changes: 2 additions & 0 deletions util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ def diff(old, new):
return {'added': added, 'removed': removed, 'old': old, 'new': new}


# https://evandrocoan.github.io/debugtools/html/classdebug__tools_1_1utilities_1_1diffmatchpatch.html
def diff_linesToWords(text1, text2, delimiter=re.compile('\n')):
"""
Split two texts into an array of strings. Reduce the texts to a string
of hashes where each Unicode character represents one line.
Expand Down
2 changes: 1 addition & 1 deletion view/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
DARKMODE_TEXT = "white"
DARKMODE_UNCANONICAL = BLUE

DARKMODE_EDIT_BG = '#575757'##636363'#"#747474"#'#363636'
DARKMODE_EDIT_BG = '#525252'#575757'##636363'#"#747474"#'#363636'

DARKMODE_NOT_VISITED = '#636363'
DARKMODE_VISITED = "#474747"
Expand Down
Loading

0 comments on commit d6809cd

Please sign in to comment.