Skip to content

Commit

Permalink
save wavefunction image
Browse files Browse the repository at this point in the history
  • Loading branch information
socketteer committed Feb 23, 2022
1 parent 4979101 commit 9249a25
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
14 changes: 11 additions & 3 deletions components/block_multiverse.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import math
import tkinter
import uuid
import openai

from tkinter import ttk
from decimal import *
from util.custom_tks import TextAware
from util.gpt_util import logprobs_to_probs
from util.tokenizer import tokenize, token_to_word

from PIL import Image
import PIL.ImageGrab as ImageGrab

rainbow_colors = ['#9400D3', '#4B0082', '#0000FF', '#00FF00', '#FFFF00', '#FF7F00', '#FF0000']

Expand Down Expand Up @@ -319,7 +320,6 @@ def draw_block(self, x, y, token, prompt, probability, height, block_width, is_g
def map_to_scaled_coordinates(self, x, y):
x = x - self.window_offset[0]
y = y - self.window_offset[1]
print(y)
y = y * self.y_scale
return x, y

Expand All @@ -341,3 +341,11 @@ def draw_text_absolute(self, x, y, **kwargs):
#rel_y = int(round(rel_y))
return self.canvas.create_text(rel_x, rel_y, **kwargs)

def save_as_png(self, filename):
# grabcanvas=ImageGrab.grab(bbox=self.canvas).save("test.png")
# ttk.grabcanvas.save("test.png")

self.canvas.postscript(file = filename + '.eps')
# use PIL to convert to PNG
img = Image.open(filename + '.eps')
img.save(filename + '.png', 'png', quality=100)
11 changes: 9 additions & 2 deletions components/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,12 +1916,13 @@ def __init__(self, callbacks, state):
self.max_depth_entry = None
self.threshold_entry = None
self.model_dropdown = None
# buttons: propagate, clear, center, add path to tree
# buttons: propagate, clear, center, add path to tree, save image
self.propagate_button = None
self.clear_button = None
self.add_path_button = None
self.reset_zoom_button = None
self.model_list = ["ada", "babbage", "curie", "davinci", "gpt-j-6b", "gpt-neo-20b"]
self.save_image_button = None
self.model_list = ["ada", "ada", "babbage", "curie", "davinci", "gpt-j-6b", "gpt-neo-20b"]

self.ground_truth_textbox = None
Module.__init__(self, 'wavefunction', callbacks, state)
Expand Down Expand Up @@ -1975,6 +1976,8 @@ def build(self, parent):
self.reset_zoom_button.pack(side='left')
self.add_path_button = ttk.Button(self.buttons_frame, text="Add path to tree", compound='right', command=self.add_path)
self.add_path_button.pack(side='left')
self.save_image_button = ttk.Button(self.buttons_frame, text="Save image", compound='right', command=self.save_image)
self.save_image_button.pack(side='left')

self.set_config()

Expand Down Expand Up @@ -2021,6 +2024,10 @@ def add_path(self):
new_child['text'] = prompt
self.state.tree_updated(add=[new_child['id']])

def save_image(self):
prompt = self.state.default_prompt(quiet=True, node=self.state.selected_node)
self.wavefunction.save_as_png(f'{prompt[-20:]}_{self.model.get()}.png')

def tree_updated(self):
pass

Expand Down

0 comments on commit 9249a25

Please sign in to comment.