Skip to content

Commit

Permalink
counterfactual tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
socketteer committed Apr 17, 2021
1 parent 1a75c62 commit 1434976
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 51 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ Child edit: `c`

### Navigate

Click to go to node: `Control-shift-click`

Next: `period`, `Return`, `Control-period`

Prev: `comma`, `Control-comma`
Expand Down Expand Up @@ -113,6 +115,9 @@ Toggle bookmark: `b`, `Control-b`

Search: `Control-f`

Click to split node: `Control-alt-click`



### Edit topology

Expand All @@ -139,6 +144,16 @@ New Sibling: `Alt-Down`

### Edit text

Click to edit history: `Control-click`

Click to select token: `Alt-click`

Next counterfactual token: `Alt-period`

Previous counterfactual token: `Alt-comma`

Apply counterfactual changes: `Alt-return`

Enter text: `Control-bar`

Escape textbox: `Escape`
Expand Down
5 changes: 4 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

- counterfactual tokens
- save changes or make new branch
- highlight indicates whether edits have been saved
- highlight indicates whether edits have been saved
- save modified state

- new memory system

# bugs

Expand Down
127 changes: 98 additions & 29 deletions controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import pyperclip
import bisect

import traceback

from view.colors import history_color, not_visited_color, visited_color, ooc_color, text_color, uncanonical_color
from view.display import Display
from view.dialogs import GenerationSettingsDialog, InfoDialog, VisualizationSettingsDialog, \
Expand Down Expand Up @@ -88,6 +90,7 @@ def register_model_callbacks(self):
self.state.register_callback(self.state.selection_updated, self.refresh_textbox)
self.state.register_callback(self.state.selection_updated, self.refresh_vis_selection)
self.state.register_callback(self.state.selection_updated, self.refresh_notes)
self.state.register_callback(self.state.selection_updated, self.refresh_counterfactual_meta)


def setup_key_bindings(self):
Expand Down Expand Up @@ -190,7 +193,8 @@ def next(self):
if self.state.preferences["canonical_only"]:
self.state.next_canonical()
else:
self.state.traverse_tree(1)
#self.state.traverse_tree(1)
self.select_node(node=self.state.next(1))
# next.meta = dict(name="Next", keys=["<period>", "<Return>"], display_key=">")
# .meta = dict(name=, keys=, display_key=)

Expand All @@ -199,34 +203,41 @@ def prev(self):
if self.state.preferences["canonical_only"]:
self.state.prev_canonical()
else:
self.state.traverse_tree(-1)
#self.state.traverse_tree(-1)
self.select_node(node=self.state.next(-1))

@metadata(name="Go to parent", keys=["<Left>", "<Control-Left>"], display_key="←")
def parent(self):
self.state.select_parent()
#self.state.select_parent()
self.select_node(node=self.state.parent())

@metadata(name="Go to child", keys=["<Right>", "<Control-Right>"], display_key="→")
def child(self):
self.state.select_child(0)
#self.state.select_child(0)
self.select_node(node=self.state.child())

@metadata(name="Go to next sibling", keys=["<Down>", "<Control-Down>"], display_key="↓")
def next_sibling(self):
self.state.select_sibling(1)
#self.state.select_sibling(1)
self.select_node(node=self.state.sibling(1))

@metadata(name="Go to previous Sibling", keys=["<Up>", "<Control-Up>"], display_key="↑")
def prev_sibling(self):
self.state.select_sibling(-1)
#self.state.select_sibling(-1)
self.select_node(node=self.state.sibling(-1))

@metadata(name="Walk", keys=["<Key-w>", "<Control-w>"], display_key="w")
def walk(self, canonical_only=False):
filter_set = self.state.calc_canonical_set() if canonical_only else None
if 'children' in self.state.selected_node and len(self.state.selected_node['children']) > 0:
chosen_child = stochastic_transition(self.state.selected_node, mode='descendents', filter_set=filter_set)
self.state.select_node(chosen_child['id'])
#self.state.select_node(chosen_child['id'])
self.select_node(node=chosen_child)

@metadata(name="Return to root", keys=["<Key-r>", "<Control-r>"], display_key="r")
def return_to_root(self):
self.state.select_node(self.state.tree_raw_data["root"]["id"])
#self.state.select_node(self.state.tree_raw_data["root"]["id"])
self.select_node(node=self.state.tree_raw_data["root"])

@metadata(name="Save checkpoint", keys=["<Control-t>"], display_key="ctrl-t")
def save_checkpoint(self, node=None):
Expand All @@ -238,19 +249,21 @@ def save_checkpoint(self, node=None):
@metadata(name="Go to checkpoint", keys=["<Key-t>"], display_key="t")
def goto_checkpoint(self):
if self.state.checkpoint:
self.state.select_node(self.state.checkpoint)
#self.state.select_node(self.state.checkpoint)
self.select_node(node=self.state.tree_node_dict[self.state.checkpoint])

@metadata(name="Nav Select")
def nav_select(self, *, node_id):
if not node_id:
if not node_id or node_id == self.state.selected_node_id:
return
if self.change_parent.meta["click_mode"]:
self.change_parent(node=self.state.tree_node_dict[node_id])
# TODO This causes infinite recursion from the vis node. Need to change how updating open status works
# Update the open state of the node based on the nav bar
# node = self.state.tree_node_dict[node_id]
# node["open"] = self.display.nav_tree.item(node["id"], "open")
self.state.select_node(node_id)
#self.state.select_node(node_id)
self.select_node(node=self.state.tree_node_dict[node_id])

@metadata(name="Bookmark", keys=["<Key-b>", "<Control-b>"], display_key="b")
def bookmark(self, node=None):
Expand Down Expand Up @@ -293,17 +306,20 @@ def center_view(self):
#self.display.vis.center_view_on_canvas_coords(*self.display.vis.node_coords[self.state.selected_node_id])
self.display.vis.center_view_on_node(self.state.selected_node)

@metadata(name="Select node")
def select_node(self, node):
if self.state.preferences['coloring'] == 'read':
old_node = self.state.selected_node
self.state.select_node(node['id'])
_, index = nearest_common_ancestor(old_node, node, self.state.tree_node_dict)
nca_node, index = nearest_common_ancestor(old_node, node, self.state.tree_node_dict)
nca_end_index = self.ancestor_end_indices[index]
self.display.textbox.tag_delete("old")
self.display.textbox.tag_add("old",
"1.0",
f"1.0 + {nca_end_index} chars")
self.display.textbox.tag_config("old", foreground=history_color())
#traceback.print_stack()
print('done')
else:
self.state.select_node(node['id'])

Expand Down Expand Up @@ -456,7 +472,7 @@ def split_node(self, index, change_selection=True):
self.nav_select(node_id=new_parent["id"])
# TODO deal with metadata

@metadata(name="Select token", keys=[], display_key="")
@metadata(name="Select token", keys=[], display_key="", selected_node=None, token_index=None)
def select_token(self, index):
if self.display.mode == "Read":
self.display.textbox.tag_remove("selected", "1.0", 'end')
Expand All @@ -466,11 +482,12 @@ def select_token(self, index):
offset = len(selected_node['text']) - negative_offset

# TODO new token offsets if changed
if "meta" in selected_node and "generation" in selected_node["meta"]:
if "meta" in selected_node and "generation" in selected_node["meta"] and not selected_node['meta']['modified']:
self.change_token.meta["counterfactual_index"] = 0
self.change_token.meta["prev_token"] = None
token_offsets = [n - len(selected_node['meta']['generation']['prompt'])
for n in selected_node['meta']['generation']["logprobs"]["text_offset"]]
#token_offsets = [n - len(selected_node['meta']['generation']['prompt'])
# for n in selected_node['meta']['generation']["logprobs"]["text_offset"]]
token_offsets = selected_node['meta']['generation']["logprobs"]["text_offset"]

token_index = bisect.bisect_left(token_offsets, offset) - 1
start_position = token_offsets[token_index]
Expand All @@ -489,42 +506,87 @@ def select_token(self, index):
self.select_token.meta["selected_node"] = selected_node
self.select_token.meta["token_index"] = token_index

@metadata(name="Change token", keys=["<Control-Shift-KeyPress-Down>"], display_key="", counterfactual_index=0, prev_token=None)
def change_token(self, node=None, token_index=None):
@metadata(name="Change token", keys=[], display_key="", counterfactual_index=0, prev_token=None, temp_token_offsets=None)
def change_token(self, node=None, token_index=None, traverse=1):
if not self.select_token.meta["selected_node"]:
return
elif not node:
node = self.select_token.meta["selected_node"]
token_index = self.select_token.meta["token_index"]
token_offsets = [n - len(node['meta']['generation']['prompt'])
for n in node['meta']['generation']["logprobs"]["text_offset"]]

if not self.change_token.meta['temp_token_offsets']:
#token_offsets = [n - len(node['meta']['generation']['prompt'])
# for n in node['meta']['generation']["logprobs"]["text_offset"]]
token_offsets = node['meta']['generation']["logprobs"]["text_offset"]
self.change_token.meta['temp_token_offsets'] = token_offsets
else:
token_offsets = self.change_token.meta['temp_token_offsets']

start_position = token_offsets[token_index]
token = node['meta']['generation']["logprobs"]["tokens"][token_index]
counterfactuals = node['meta']['generation']["logprobs"]["top_logprobs"][token_index].copy()
original_token = (token, counterfactuals.pop(token, None))
index = node_index(node, self.state.tree_node_dict)
sorted_counterfactuals = list(sorted(counterfactuals.items(), key=lambda item: item[1], reverse=True))
sorted_counterfactuals.insert(0, original_token)
#print('start position: ', self.ancestor_end_indices[index - 1] + start_position)
print(sorted_counterfactuals)
self.change_token.meta["counterfactual_index"] += 1

self.change_token.meta["counterfactual_index"] += traverse
if self.change_token.meta["counterfactual_index"] < 0 or self.change_token.meta["counterfactual_index"] > len(sorted_counterfactuals) - 1:
self.change_token.meta["counterfactual_index"] -= traverse
return

new_token = sorted_counterfactuals[self.change_token.meta["counterfactual_index"]][0]
self.display.textbox.config(state="normal")
if not self.change_token.meta['prev_token']:
self.change_token.meta['prev_token'] = token


self.display.textbox.config(state="normal")


self.display.textbox.delete(f"1.0 + {self.ancestor_end_indices[index - 1] + start_position} chars",
f"1.0 + {self.ancestor_end_indices[index - 1] + start_position + len(self.change_token.meta['prev_token'])} chars")
self.display.textbox.insert(f"1.0 + {self.ancestor_end_indices[index - 1] + start_position} chars", new_token)

self.display.textbox.config(state="disabled")
self.display.textbox.tag_add("selected",
self.display.textbox.tag_add("modified",
f"1.0 + {self.ancestor_end_indices[index - 1] + start_position} chars",
f"1.0 + {self.ancestor_end_indices[index - 1] + start_position + len(new_token)} chars")


#update temp token offsets
diff = len(new_token) - len(self.change_token.meta['prev_token'])
for index, offset in enumerate(self.change_token.meta['temp_token_offsets'][token_index:]):
self.change_token.meta['temp_token_offsets'][index] += diff

self.change_token.meta['prev_token'] = new_token

@metadata(name="Reset zoom", keys=["<Control-0>"], display_key="Ctrl-0")
def reset_zoom(self):
self.display.vis.reset_zoom()
@metadata(name="Next token", keys=["<Alt-period>"], display_key="", counterfactual_index=0, prev_token=None)
def next_token(self, node=None, token_index=None):
self.change_token(node, token_index, traverse=1)

@metadata(name="Prev token", keys=["<Alt-comma>"], display_key="", counterfactual_index=0, prev_token=None)
def prev_token(self, node=None, token_index=None):
self.change_token(node, token_index, traverse=-1)

@metadata(name="Apply counterfactual", keys=["<Alt-Return>"], display_key="", counterfactual_index=0, prev_token=None)
def apply_counterfactual_changes(self):
# TODO apply to non selected nodes
index = node_index(self.state.selected_node, self.state.tree_node_dict)

new_text = self.display.textbox.get(f"1.0 + {self.ancestor_end_indices[index - 1]} chars", "end-1c")
self.state.update_text(node=self.state.selected_node, text=new_text, modified_flag=False)
self.display.textbox.tag_remove("modified", "1.0", 'end-1c')

# TODO don't do this if no meta
self.state.selected_node['meta']['generation']["logprobs"]["text_offset"] = self.change_token.meta['temp_token_offsets']
self.refresh_counterfactual_meta()


@metadata(name="Refresh counterfactual")
def refresh_counterfactual_meta(self):
self.change_token.meta['prev_token'] = None
self.change_token.meta['temp_token_offsets'] = None


#################################
# State
Expand Down Expand Up @@ -900,7 +962,9 @@ def refresh_textbox(self, **kwargs):
self.display.textbox.tag_config('ooc_history', foreground=text_color())
self.display.textbox.tag_config('history', foreground=text_color())

self.display.textbox.tag_config("selected", background="blue", foreground=text_color())
# TODO bad color for lightmode
self.display.textbox.tag_config("selected", background="black", foreground=text_color())
self.display.textbox.tag_config("modified", background="blue", foreground=text_color())
ancestry, indices = self.state.node_ancestry_text()
self.ancestor_end_indices = indices
history = ''
Expand Down Expand Up @@ -981,6 +1045,11 @@ def refresh_vis_selection(self):
# self.display.vis.draw(self.state.tree_raw_data["root"], self.state.selected_node)
# self.display.vis.center_view_on_canvas_coords(*self.display.vis.node_coords[self.state.selected_node_id])

@metadata(name="Reset zoom", keys=["<Control-0>"], display_key="Ctrl-0")
def reset_zoom(self):
self.display.vis.reset_zoom()


def refresh_notes(self):
if not self.state.tree_raw_data or not self.state.selected_node or not self.state.preferences['side_pane']:
return
Expand Down
37 changes: 31 additions & 6 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,6 @@ def node(self, node_id=None):
return self.selected_node
return self.tree_node_dict[node_id] if self.tree_node_dict and node_id in self.tree_node_dict else None

def parent(self, node):
return self.tree_node_dict[node['parent_id']]

# Get a nodes chapter by finding its chapter or its nearest parent's chapter
def chapter(self, node):
chapter_id = get_inherited_attribute("chapter_id", node, self.tree_node_dict)
Expand Down Expand Up @@ -268,9 +265,11 @@ def traverse_tree(self, offset):
return self.select_node(new_node_id)

def next_id(self, offset):
new_idx = clip_num(self.tree_traversal_idx + offset, 0, len(self.tree_node_dict) - 1)
return self.nodes[new_idx]["id"]
return self.next(offset)["id"]

def next(self, offset):
new_idx = clip_num(self.tree_traversal_idx + offset, 0, len(self.tree_node_dict) - 1)
return self.nodes[new_idx]

# TODO this is bad
def next_canonical(self):
Expand Down Expand Up @@ -311,6 +310,23 @@ def select_sibling(self, offset, node=None):
sibling = siblings[(siblings.index(node) + offset) % len(siblings)]
return self.select_node(sibling["id"])

# return parent
def parent(self, node=None):
node = node if node else self.selected_node
return self.tree_node_dict[node['parent_id']]

# return child
def child(self, child_num, node=None):
node = node if node else self.selected_node
if node and len(node["children"]) > 0:
return index_clip(node["children"], child_num)["id"]

# return sibling
def sibling(self, offset, node=None):
node = node if node else self.selected_node
if node and "parent_id" in node:
siblings = self.parent(node)["children"]
return siblings[(siblings.index(node) + offset) % len(siblings)]

#################################
# Updates
Expand Down Expand Up @@ -455,7 +471,7 @@ def delete_node(self, node=None, reassign_children=False):
self.select_node(siblings[old_index % len(siblings)]["id"])
self.tree_updated(delete=[node['id']])

def update_text(self, node, text, active_text=None):
def update_text(self, node, text, active_text=None, modified_flag=True):
assert node["id"] in self.tree_node_dict, text

# Remove trailing spaces
Expand All @@ -478,6 +494,10 @@ def update_text(self, node, text, active_text=None):
edited = True

if edited:
if modified_flag:
if not node['meta']:
node['meta'] = {}
node['meta']['modified'] = True
self.tree_updated(edit=[node['id']])


Expand Down Expand Up @@ -705,6 +725,11 @@ def generate_for_nodes(self, prompt, nodes, grandchildren=None):
node["meta"]["modified"] = False
node["meta"]["origin"] = "generated"

# remove offset of prompt
# TODO fix old nodes
corrected_text_offset = [n - len(prompt) for n in node['meta']['generation']["logprobs"]["text_offset"]]
node['meta']['generation']["logprobs"]["text_offset"] = corrected_text_offset

else:
print("ERROR. Deleting failures")
for node in nodes:
Expand Down
Loading

0 comments on commit 1434976

Please sign in to comment.