add physical traversal through beam states
dpfried committed May 22, 2018
1 parent 8beac87 commit ba393a1
Showing 2 changed files with 103 additions and 15 deletions.
90 changes: 80 additions & 10 deletions tasks/R2R/follower.py
@@ -18,6 +18,17 @@

InferenceState = namedtuple("InferenceState", "prev_inference_state, world_state, observation, flat_index, last_action, last_action_embedding, action_count, score, h_t, c_t, last_alpha")

Cons = namedtuple("Cons", "first, rest")

def cons_to_list(cons):
    l = []
    while True:
        l.append(cons.first)
        cons = cons.rest
        if cons is None:
            break
    return l
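
A Cons chain is a singly linked list built head-first; a minimal usage sketch with hypothetical values, assuming a non-empty chain:

    chain = Cons(1, Cons(2, Cons(3, None)))  # represents [1, 2, 3]
    assert cons_to_list(chain) == [1, 2, 3]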

def backchain_inference_states(last_inference_state):
    states = []
    observations = []
@@ -38,6 +49,29 @@ def backchain_inference_states(last_inference_state):
    scores.append(last_score)
    return list(reversed(states)), list(reversed(observations)), list(reversed(actions))[1:], list(reversed(scores))[1:], list(reversed(attentions))[1:]  # exclude start action

def least_common_viewpoint_path(inf_state_a, inf_state_b):
    # return inference states traversing from A to X, then from Y to B,
    # where X and Y are the least common ancestors of A and B, respectively, that share a viewpointId
    path_to_b_by_viewpoint = {}
    b = inf_state_b
    b_stack = Cons(b, None)
    while b is not None:
        path_to_b_by_viewpoint[b.world_state.viewpointId] = b_stack
        b = b.prev_inference_state
        b_stack = Cons(b, b_stack)
    a = inf_state_a
    path_from_a = [a]
    while a is not None:
        vp = a.world_state.viewpointId
        if vp in path_to_b_by_viewpoint:
            path_to_b = cons_to_list(path_to_b_by_viewpoint[vp])
            assert path_from_a[-1].world_state.viewpointId == path_to_b[0].world_state.viewpointId
            return path_from_a + path_to_b[1:]
        a = a.prev_inference_state
        path_from_a.append(a)
    raise AssertionError("no common ancestor found")
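
A minimal sketch of the intended behavior, using hypothetical stub states that carry only the field this function reads (namedtuple is already available in this module, as the definitions above show):

    WS = namedtuple("WS", "viewpointId")

    def _stub(vp, prev):
        # hypothetical helper: an InferenceState populated only where needed
        return InferenceState(prev_inference_state=prev, world_state=WS(vp),
                              observation=None, flat_index=None, last_action=-1,
                              last_action_embedding=None, action_count=0,
                              score=0.0, h_t=None, c_t=None, last_alpha=None)

    root = _stub("v0", None)
    a = _stub("v2", _stub("v1", root))  # A's chain: v2 <- v1 <- v0
    b = _stub("v3", root)               # B's chain: v3 <- v0
    path = [s.world_state.viewpointId for s in least_common_viewpoint_path(a, b)]
    assert path == ["v2", "v1", "v0", "v3"]  # back from A to v0, then forward to B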

def batch_instructions_from_encoded(encoded_instructions, max_length, reverse=False, sort=False):
    # encoded_instructions: list of lists of token indices (should not be padded, or contain BOS or EOS tokens)
    #seq_tensor = np.array(encoded_instructions)
@@ -302,7 +336,7 @@ def rollout(self):
            return self._rollout_with_loss()
        else:
            assert self.beam_size >= 1
            beams = self.beam_search(self.beam_size)
            beams, _, _ = self.beam_search(self.beam_size)
            return [beam[0] for beam in beams]

    def _score_obs_actions_and_instructions(self, path_obs, path_actions, encoded_instructions):
@@ -680,7 +714,8 @@ def beam_search(self, beam_size, load_next_minibatch=True, mask_undo=False):
                    'attentions': path_attentions
                })
            trajs.append(this_trajs)
        return trajs
        traversed_lists = None  # todo: physical traversal is not yet computed for plain beam search
        return trajs, completed, traversed_lists
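
Call sites must now unpack three values; a minimal sketch of the new contract (traversed_lists stays None here until the todo above is addressed):

    trajs, completed, traversed_lists = follower.beam_search(beam_size)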

    def state_factored_search(self, completion_size, successor_size, load_next_minibatch=True, mask_undo=False, first_n_ws_key=4):
        assert self.env.beam_size >= successor_size
@@ -702,19 +737,47 @@ def state_factored_search(self, completion_size, successor_size, load_next_minibatch=True, mask_undo=False, first_n_ws_key=4):

        state_cache = [
            {ws[0][0:first_n_ws_key]: (InferenceState(prev_inference_state=None,
                                                      world_state=ws[0],
                                                      observation=o[0],
                                                      flat_index=None,
                                                      last_action=-1,
                                                      last_action_embedding=self.decoder.u_begin.view(-1),
                                                      action_count=0,
                                                      score=0.0, h_t=h_t[i], c_t=c_t[i], last_alpha=None), True)}
            for i, (ws, o) in enumerate(zip(world_states, initial_obs))
        ]

        beams = [[inf_state for world_state, (inf_state, expanded) in sorted(instance_cache.items())]
                 for instance_cache in state_cache]  # sorting is a noop here since each instance_cache should only contain one entry

        # per-instance lists of inference states, in the order the search
        # physically visits them
        last_expanded_list = []
        traversed_lists = []
        for beam in beams:
            assert len(beam) == 1
            first_state = beam[0]
            last_expanded_list.append(first_state)
            traversed_lists.append([first_state])

        def update_traversed_lists(new_visited_inf_states):
            assert len(new_visited_inf_states) == len(last_expanded_list)
            assert len(new_visited_inf_states) == len(traversed_lists)

            for instance_index, instance_states in enumerate(new_visited_inf_states):
                last_expanded = last_expanded_list[instance_index]
                # todo: if this passes, shouldn't need traversed_lists
                assert last_expanded.world_state.viewpointId == traversed_lists[instance_index][-1].world_state.viewpointId
                for inf_state in instance_states:
                    path_from_last_to_next = least_common_viewpoint_path(last_expanded, inf_state)
                    # path_from_last_to_next should include last_expanded's world state as the first element, so check and drop it
                    assert path_from_last_to_next[0].world_state.viewpointId == last_expanded.world_state.viewpointId
                    assert path_from_last_to_next[-1].world_state.viewpointId == inf_state.world_state.viewpointId
                    traversed_lists[instance_index].extend(path_from_last_to_next[1:])
                    last_expanded = inf_state
                last_expanded_list[instance_index] = last_expanded
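
On a hypothetical instance, this stitching keeps each traversed list physically contiguous; reusing the hypothetical _stub helper from the earlier sketch:

    v0 = _stub("v0", None)
    v1 = _stub("v1", v0)
    v3 = _stub("v3", v0)
    traversed = [v0, v1]                        # search last expanded v1
    path = least_common_viewpoint_path(v1, v3)  # -> states at [v1, v0, v3]
    traversed.extend(path[1:])                  # now [v0, v1, v0, v3]: backtracks through v0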

        # Do a sequence rollout and calculate the loss
        while any(len(comp) < completion_size for comp in completed):
            beam_indices = []
@@ -873,6 +936,7 @@ def state_factored_search(self, completion_size, successor_size, load_next_minibatch=True, mask_undo=False, first_n_ws_key=4):
            successor_obs = self.env.observe(world_states, beamed=True)
            beams = structured_map(lambda inf_state, obs: inf_state._replace(observation=obs),
                                   beams, successor_obs, nested=True)
            update_traversed_lists(beams)

        completed_list = []
        for this_completed in completed:
@@ -884,6 +948,10 @@ def state_factored_search(self, completion_size, successor_size, load_next_minibatch=True, mask_undo=False, first_n_ws_key=4):
        completed_obs = self.env.observe(completed_ws, beamed=True)
        completed_list = structured_map(lambda inf_state, obs: inf_state._replace(observation=obs),
                                        completed_list, completed_obs, nested=True)
        # TODO: consider moving observations and this update earlier so that we don't have to traverse as far back
        update_traversed_lists(completed_list)

        # TODO: sanity check the traversed lists here

        trajs = []
        for this_completed in completed_list:
@@ -907,7 +975,9 @@ def state_factored_search(self, completion_size, successor_size, load_next_minibatch=True, mask_undo=False, first_n_ws_key=4):
                    'attentions': path_attentions
                })
            trajs.append(this_trajs)
        return trajs
        # completed_list: list of lists of final inference states corresponding to the candidates, one list per instance
        # traversed_lists: list of "physical states" that the robot has explored, one per instance
        return trajs, completed_list, traversed_lists

    def set_beam_size(self, beam_size):
        if self.env.beam_size < beam_size:
28 changes: 23 additions & 5 deletions tasks/R2R/rational_follower.py
@@ -8,6 +8,8 @@
import MatterSim
import env

from follower import least_common_viewpoint_path, path_element_from_observation

import numpy as np

from collections import namedtuple, Counter
@@ -18,7 +20,7 @@ def run_rational_follower(
        envir, evaluator, follower, speaker, beam_size,
        include_gold=False, output_file=None, eval_file=None,
        compute_oracle=False, mask_undo=False, state_factored_search=False,
        state_first_n_ws_key=4):
        state_first_n_ws_key=4, physical_traversal=False):
    follower.env = envir
    envir.reset_epoch()
@@ -46,9 +48,9 @@

follower.feedback = feedback_method
if state_factored_search:
beam_candidates = follower.state_factored_search(beam_size, 1, load_next_minibatch=not include_gold, mask_undo=mask_undo, first_n_ws_key=state_first_n_ws_key)
beam_candidates, candidate_inf_states, traversed_lists = follower.state_factored_search(beam_size, 1, load_next_minibatch=not include_gold, mask_undo=mask_undo, first_n_ws_key=state_first_n_ws_key)
else:
beam_candidates = follower.beam_search(beam_size, load_next_minibatch=not include_gold, mask_undo=mask_undo)
beam_candidates, candidate_inf_states, traversed_lists = follower.beam_search(beam_size, load_next_minibatch=not include_gold, mask_undo=mask_undo)

    if include_gold:
        assert len(gold_candidates) == len(beam_candidates)
@@ -67,14 +69,27 @@
    speaker_scored_candidates, _ = speaker._score_obs_actions_and_instructions(cand_obs, cand_actions, cand_instr, feedback='teacher')
    assert len(speaker_scored_candidates) == sum(len(l) for l in beam_candidates)
    start_index = 0
    for instance_candidates in beam_candidates:
    for instance_index, instance_candidates in enumerate(beam_candidates):
        for i, candidate in enumerate(instance_candidates):
            speaker_scored_candidate = speaker_scored_candidates[start_index + i]
            assert candidate['instr_id'] == speaker_scored_candidate['instr_id']
            candidate['follower_score'] = candidate['score']
            candidate['speaker_score'] = speaker_scored_candidate['score']
            # Delete keys not needed for later processing
            del candidate['observations']
            if physical_traversal:
                # Replace the candidate's trajectory with the full physical path the
                # search traversed, extended to end at this candidate's final state.
                last_traversed = traversed_lists[instance_index][-1]
                candidate_inf_state = candidate_inf_states[instance_index][i]
                path_from_last_to_next = least_common_viewpoint_path(last_traversed, candidate_inf_state)
                assert path_from_last_to_next[0].world_state.viewpointId == last_traversed.world_state.viewpointId
                assert path_from_last_to_next[-1].world_state.viewpointId == candidate_inf_state.world_state.viewpointId

                inf_traj = traversed_lists[instance_index] + path_from_last_to_next[1:]
                physical_trajectory = [path_element_from_observation(inf_state.observation)
                                       for inf_state in inf_traj]
                # make sure the viewpointIds match
                assert physical_trajectory[-1][0] == candidate['trajectory'][-1][0]
                candidate['trajectory'] = physical_trajectory
            if compute_oracle:
                candidate['eval_result'] = evaluator._score_item(candidate['instr_id'], candidate['trajectory'])._asdict()
        start_index += len(instance_candidates)
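
A worked sketch of the reconstruction on one hypothetical instance, in viewpoint ids only:

    # search explored v0 -> v1 -> (back to) v0 -> v3; this candidate ends at v1
    # traversed_lists[instance_index] visits:            [v0, v1, v0, v3]
    # least_common_viewpoint_path(v3_state, v1_state) -> [v3, v0, v1]
    # physical trajectory visits: [v0, v1, v0, v3] + [v0, v1]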
@@ -182,7 +197,9 @@ def validate_entry_point(args):
        eval_file=eval_file, compute_oracle=args.compute_oracle,
        mask_undo=args.mask_undo,
        state_factored_search=args.state_factored_search,
        state_first_n_ws_key=args.state_first_n_ws_key)
        state_first_n_ws_key=args.state_first_n_ws_key,
        physical_traversal=args.physical_traversal,
    )
    pprint.pprint(accuracies_by_weight)
    pprint.pprint({w: sorted(d.items()) for w, d in index_counts_by_weight.items()})
    weight, score_summary = max(accuracies_by_weight.items(), key=lambda pair: pair[1]['success_rate'])
@@ -203,6 +220,7 @@ def make_arg_parser():
parser.add_argument("--mask_undo", action='store_true')
parser.add_argument("--state_factored_search", action='store_true')
parser.add_argument("--state_first_n_ws_key", type=int, default=4)
parser.add_argument("--physical_traversal", action='store_true')

return parser

Expand Down
