Skip to content

Commit

Permalink
add gooseai and wavefunction propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
socketteer committed Feb 23, 2022
1 parent 0abb68a commit 7fe9577
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 101 deletions.
32 changes: 20 additions & 12 deletions components/block_multiverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, parent_frame):
self.canvas = None
self.wavefunction = None
self.selected_id = None
self.window_height = 1000
self.window_height = 450
self.node_info = {}
self.build_canvas()
self.build_past_box()
Expand All @@ -49,23 +49,26 @@ def build_canvas(self):
self.multiverse_frame.pack(expand=True, fill=tkinter.BOTH)
self.canvas = tkinter.Canvas(self.multiverse_frame, bg="#808080")

hbar = tkinter.Scrollbar(self.multiverse_frame, orient=tkinter.HORIZONTAL)
hbar.pack(side=tkinter.BOTTOM, fill=tkinter.X)
hbar.config(command=self.canvas.xview)
# hbar = tkinter.Scrollbar(self.multiverse_frame, orient=tkinter.HORIZONTAL)
# hbar.pack(side=tkinter.BOTTOM, fill=tkinter.X)
# hbar.config(command=self.canvas.xview)

vbar = tkinter.Scrollbar(self.multiverse_frame, orient=tkinter.VERTICAL)
vbar.pack(side=tkinter.RIGHT, fill=tkinter.Y)
vbar.config(command=self.canvas.yview)
# vbar = tkinter.Scrollbar(self.multiverse_frame, orient=tkinter.VERTICAL)
# vbar.pack(side=tkinter.RIGHT, fill=tkinter.Y)
# vbar.config(command=self.canvas.yview)

self.canvas.config(
xscrollcommand=hbar.set,
yscrollcommand=vbar.set
)
# self.canvas.config(
# xscrollcommand=hbar.set,
# yscrollcommand=vbar.set
# )

self.canvas.pack(side=tkinter.LEFT, expand=True, fill=tkinter.BOTH)
#self.multiverse_frame.update_idletasks()
#self.window_height = self.multiverse_frame.winfo_reqheight() * 2

def build_past_box(self):
self.bottom_input_frame = ttk.Frame(self.frame)
self.bottom_input_frame.pack(side="bottom", fill="both")
self.bottom_input_frame.pack(side="bottom", fill="x")
self.past_box = TextAware(self.bottom_input_frame, bd=3, height=3)
self.past_box.pack(expand=True, fill='x')
self.past_box.configure(
Expand Down Expand Up @@ -195,6 +198,7 @@ def draw_multiverse(self, multiverse, ground_truth='', block_width=150, start_po
if not self.prompt:
self.prompt = prompt
self.set_pastbox_text(prompt_text=self.prompt)

if not self.wavefunction:
self.wavefunction = multiverse
else:
Expand All @@ -203,6 +207,9 @@ def draw_multiverse(self, multiverse, ground_truth='', block_width=150, start_po
prefix = self.node_info[self.selected_id]['prefix']
else:
return
if start_position == (0, 0):
self.draw_block(0, 0, self.prompt[-20:], prefix, 1, Decimal(self.window_height), block_width, True,
show_text, 0)
self.propagate(multiverse, ground_truth, prefix, block_width, start_position, color_index, show_text,
y_offset=0, depth=1)

Expand Down Expand Up @@ -312,6 +319,7 @@ def draw_block(self, x, y, token, prompt, probability, height, block_width, is_g
def map_to_scaled_coordinates(self, x, y):
x = x - self.window_offset[0]
y = y - self.window_offset[1]
print(y)
y = y * self.y_scale
return x, y

Expand Down
11 changes: 8 additions & 3 deletions components/dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tkinter.font import Font
from tkinter.scrolledtext import ScrolledText

from gpt import POSSIBLE_MODELS
# from gpt import POSSIBLE_MODELS
from util.custom_tks import Dialog, TextAware
from util.util_tk import create_side_label, create_label, Entry, create_button, create_slider, create_combo_box, create_checkbutton
from util.util_tree import search, node_ancestry
Expand Down Expand Up @@ -1164,23 +1164,27 @@ def __init__(self, parent, state):
self.selected_model = tk.StringVar()
self.add_model_button = None
self.openai_api_key_entry = None
self.ai21_api_key_entry = None
self.openai_api_key = None
self.ai21_api_key_entry = None
self.ai21_api_key = None
self.gooseai_api_key = None
self.gooseai_api_key_entry = None
Dialog.__init__(self, parent, title="Model Configuration")

def set_vars(self):
self.available_models = self.state.model_config['models']
self.selected_model.set(self.state.generation_settings['model'])
self.openai_api_key = self.state.OPENAI_API_KEY if self.state.OPENAI_API_KEY else ""
self.ai21_api_key = self.state.AI21_API_KEY if self.state.AI21_API_KEY else ""
self.gooseai_api_key = self.state.GOOSEAI_API_KEY if self.state.GOOSEAI_API_KEY else ""

def body(self, master):
self.set_vars()
self.add_model_button = ttk.Button(master, text="Add Model", command=self.add_model)
key_length = max(max(len(self.openai_api_key), len(self.ai21_api_key)), 20)
key_length = max(max(len(self.openai_api_key), len(self.ai21_api_key), len(self.gooseai_api_key)), 20)
self.openai_api_key_entry = Entry(master, master.grid_size()[1], "OpenAI API Key", self.openai_api_key, None, width=key_length)
self.ai21_api_key_entry = Entry(master, master.grid_size()[1], "AI21 API Key", self.ai21_api_key, None, width=key_length)
self.gooseai_api_key = Entry(master, master.grid_size()[1], "GooseAI API Key", self.gooseai_api_key, None, width=key_length)
models_list = self.available_models.keys()
self.model_label = ttk.Label(master, text="Model")
self.model_label.grid(row=master.grid_size()[1], column=0)
Expand All @@ -1202,5 +1206,6 @@ def apply(self):
#'AI21_API_KEY': self.ai21_api_key_entry.tk_variables.get(),
self.state.OPENAI_API_KEY = self.openai_api_key_entry.tk_variables.get()
self.state.AI21_API_KEY = self.ai21_api_key_entry.tk_variables.get()
self.state.GOOSEAI_API_KEY = self.gooseai_api_key_entry.tk_variables.get()
self.state.update_user_frame(update={'generation_settings': {'model': self.selected_model.get()}})

129 changes: 128 additions & 1 deletion components/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import os
import json

from components.block_multiverse import BlockMultiverse

icons = Icons()


Expand Down Expand Up @@ -1898,4 +1900,129 @@ def selection_updated(self):

def tree_updated(self):
self.refresh()
self.read()
self.read()



class Wavefunction(Module):
def __init__(self, callbacks, state):
self.wavefunction = None
self.buttons_frame = None
self.config_frame = None
# config: model, max_depth, threshold
self.model = tk.StringVar()
self.threshold = tk.DoubleVar()
self.max_depth = tk.IntVar()
self.max_depth_entry = None
self.threshold_entry = None
self.model_dropdown = None
# buttons: propagate, clear, center, add path to tree
self.propagate_button = None
self.clear_button = None
self.add_path_button = None
self.reset_zoom_button = None
self.model_list = ["ada", "babbage", "curie", "davinci", "gpt-j-6b", "gpt-neo-20b"]

self.ground_truth_textbox = None
Module.__init__(self, 'wavefunction', callbacks, state)


def build(self, parent):
Module.build(self, parent)
self.config_frame = ttk.Frame(self.frame)
self.config_frame.pack(side=tk.TOP, fill=tk.X)
model_label = ttk.Label(self.config_frame, text="Model:")
model_label.pack(side=tk.LEFT)
self.model_dropdown = ttk.OptionMenu(self.config_frame, self.model, *self.model_list)
self.model_dropdown.pack(side=tk.LEFT, padx=10)

max_depth_label = ttk.Label(self.config_frame, text="Max depth:")
max_depth_label.pack(side=tk.LEFT)
self.max_depth_entry = ttk.Entry(self.config_frame, textvariable=self.max_depth, width=5)
self.max_depth_entry.pack(side=tk.LEFT, padx=10)

self.textboxes.append(self.max_depth_entry)


threshold_label = ttk.Label(self.config_frame, text="Cutoff threshold:")
threshold_label.pack(side=tk.LEFT)
self.threshold_entry = ttk.Entry(self.config_frame, textvariable=self.threshold, width=6)
self.threshold_entry.pack(side=tk.LEFT)

self.textboxes.append(self.threshold_entry)


self.wavefunction = BlockMultiverse(self.frame)
self.wavefunction.frame.pack(side=tk.TOP, expand=True, fill=tk.BOTH)

self.ground_truth_textbox = TextAware(self.frame, height=1, bd=2, undo=True)
self.ground_truth_textbox.pack(side=tk.TOP, expand=False, fill=tk.X)
self.ground_truth_textbox.configure(
foreground=text_color(),
background=edit_color(),
wrap="word",
)

self.textboxes.append(self.ground_truth_textbox)

self.buttons_frame = ttk.Frame(self.frame)
self.buttons_frame.pack(side='bottom', fill='x')
self.propagate_button = ttk.Button(self.buttons_frame, text="Propagate", compound='right', command=self.propagate)
self.propagate_button.pack(side='left')
self.clear_button = ttk.Button(self.buttons_frame, text="Clear", compound='right', command=self.clear)
self.clear_button.pack(side='left')
self.reset_zoom_button = ttk.Button(self.buttons_frame, text="Reset zoom", compound='right', command=self.reset_zoom)
self.reset_zoom_button.pack(side='left')
self.add_path_button = ttk.Button(self.buttons_frame, text="Add path to tree", compound='right', command=self.add_path)
self.add_path_button.pack(side='left')

self.set_config()


def set_config(self):
current_model = self.state.generation_settings['model']
self.model.set(current_model if current_model in self.model_list else "ada")
self.max_depth.set(3)
self.threshold.set(0.1)


def propagate(self):
if self.wavefunction.active_wavefunction():
active_node = self.wavefunction.active_info()
start_position = (active_node['x'], active_node['y'])
multiverse, ground_truth, prompt = self.state.generate_greedy_multiverse(max_depth=self.max_depth.get(),
prompt=active_node['prefix'],
unnormalized_amplitude=active_node['amplitude'],
ground_truth=self.ground_truth_textbox.get(1.0, tk.END),
threshold=self.threshold.get(),
engine=self.model.get())
else:
start_position = (0, 0)
multiverse, ground_truth, prompt = self.state.generate_greedy_multiverse(max_depth=self.max_depth.get(),
ground_truth=self.ground_truth_textbox.get(1.0, tk.END),
threshold=self.threshold.get(),
engine=self.model.get()
)

self.wavefunction.draw_multiverse(multiverse=multiverse, ground_truth=ground_truth,
start_position=start_position, prompt=prompt)

def clear(self):
self.wavefunction.clear_multiverse()

def reset_zoom(self):
self.wavefunction.reset_view()

def add_path(self):
if self.wavefunction.active_wavefunction():
active_node = self.wavefunction.active_info()
prompt=active_node['prefix']
new_child = self.state.create_child(self.state.selected_node, expand=True)
new_child['text'] = prompt
self.state.tree_updated(add=[new_child['id']])

def tree_updated(self):
pass

def selection_updated(self):
self.clear()
2 changes: 1 addition & 1 deletion components/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import codecs
from PIL import Image, ImageTk
from gpt import POSSIBLE_MODELS, gen, completions_text
from gpt import gen, completions_text
import json
import bisect
import threading
Expand Down
Loading

0 comments on commit 7fe9577

Please sign in to comment.