From ba393a10480e9126c413b3cbd1c841105faefca6 Mon Sep 17 00:00:00 2001
From: Daniel Fried
Date: Mon, 21 May 2018 18:34:16 -0700
Subject: [PATCH] add physical traversal through beam states
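
Beam search scores candidate trajectories as if the agent could jump
directly between beam states, but a real robot would have to walk a
connected path through the environment while expanding those states.
This patch tracks that path: beam_search and state_factored_search now
return a (trajs, completed_inference_states, traversed_lists) triple,
and rational_follower.py gains a --physical_traversal flag that
replaces each candidate's trajectory with the physically traversed
path, stitched to the candidate's endpoint via
least_common_viewpoint_path.

least_common_viewpoint_path backs up from one inference state to its
deepest ancestor whose viewpoint also occurs on the other state's
ancestor chain, then follows that chain forward. A minimal sketch of
the behavior on toy backpointer chains (FakeWorldState and FakeState
are hypothetical stand-ins for the real world state and InferenceState,
which carry more fields; run from tasks/R2R with follower.py's
dependencies available):

    from collections import namedtuple
    from follower import least_common_viewpoint_path

    FakeWorldState = namedtuple("FakeWorldState", "viewpointId")
    FakeState = namedtuple("FakeState", "prev_inference_state, world_state")

    def make_chain(viewpoints):
        # build a backpointer chain of states; return the final state
        state = None
        for vp in viewpoints:
            state = FakeState(prev_inference_state=state,
                              world_state=FakeWorldState(viewpointId=vp))
        return state

    a = make_chain(["start", "v1", "v2"])        # start -> v1 -> v2
    b = make_chain(["start", "v1", "v3", "v4"])  # start -> v1 -> v3 -> v4
    path = least_common_viewpoint_path(a, b)
    print([s.world_state.viewpointId for s in path])
    # ['v2', 'v1', 'v3', 'v4']: back up from A to the shared viewpoint
    # v1, then forward along B's branch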
---
 tasks/R2R/follower.py          | 90 ++++++++++++++++++++++++++++++----
 tasks/R2R/rational_follower.py | 28 +++++++++--
 2 files changed, 103 insertions(+), 15 deletions(-)

diff --git a/tasks/R2R/follower.py b/tasks/R2R/follower.py
index 2476736..c328633 100644
--- a/tasks/R2R/follower.py
+++ b/tasks/R2R/follower.py
@@ -18,6 +18,17 @@
 InferenceState = namedtuple("InferenceState", "prev_inference_state, world_state, observation, flat_index, last_action, last_action_embedding, action_count, score, h_t, c_t, last_alpha")
 
+Cons = namedtuple("Cons", "first, rest")
+
+def cons_to_list(cons):
+    l = []
+    while True:
+        l.append(cons.first)
+        cons = cons.rest
+        if cons is None:
+            break
+    return l
+
 def backchain_inference_states(last_inference_state):
     states = []
     observations = []
@@ -38,6 +49,29 @@ def backchain_inference_states(last_inference_state):
         scores.append(last_score)
     return list(reversed(states)), list(reversed(observations)), list(reversed(actions))[1:], list(reversed(scores))[1:], list(reversed(attentions))[1:] # exclude start action
 
+def least_common_viewpoint_path(inf_state_a, inf_state_b):
+    # Return inference states traversing from A back to X, then forward from Y to B,
+    # where X and Y are the least common ancestors of A and B (respectively) that share a viewpointId.
+    path_to_b_by_viewpoint = {}
+    b = inf_state_b
+    b_stack = Cons(b, None)
+    while b is not None:
+        path_to_b_by_viewpoint[b.world_state.viewpointId] = b_stack
+        b = b.prev_inference_state
+        b_stack = Cons(b, b_stack)
+    a = inf_state_a
+    path_from_a = [a]
+    while a is not None:
+        vp = a.world_state.viewpointId
+        if vp in path_to_b_by_viewpoint:
+            path_to_b = cons_to_list(path_to_b_by_viewpoint[vp])
+            assert path_from_a[-1].world_state.viewpointId == path_to_b[0].world_state.viewpointId
+            return path_from_a + path_to_b[1:]
+        a = a.prev_inference_state
+        path_from_a.append(a)
+    raise AssertionError("no common ancestor found")
+
 def batch_instructions_from_encoded(encoded_instructions, max_length, reverse=False, sort=False):
     # encoded_instructions: list of lists of token indices (should not be padded, or contain BOS or EOS tokens)
     #seq_tensor = np.array(encoded_instructions)
@@ -302,7 +336,7 @@ def rollout(self):
             return self._rollout_with_loss()
         else:
             assert self.beam_size >= 1
-            beams = self.beam_search(self.beam_size)
+            beams, _, _ = self.beam_search(self.beam_size)
             return [beam[0] for beam in beams]
 
     def _score_obs_actions_and_instructions(self, path_obs, path_actions, encoded_instructions):
@@ -680,7 +714,8 @@
                     'attentions': path_attentions
                 })
             trajs.append(this_trajs)
-        return trajs
+        traversed_lists = None # todo: physical traversal is only tracked in state_factored_search
+        return trajs, completed, traversed_lists
 
     def state_factored_search(self, completion_size, successor_size, load_next_minibatch=True, mask_undo=False, first_n_ws_key=4):
         assert self.env.beam_size >= successor_size
@@ -702,19 +737,47 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
         state_cache = [
             {ws[0][0:first_n_ws_key]: (InferenceState(prev_inference_state=None,
-                                                     world_state=ws[0],
-                                                     observation=o[0],
-                                                     flat_index=None,
-                                                     last_action=-1,
-                                                     last_action_embedding=self.decoder.u_begin.view(-1),
-                                                     action_count=0,
-                                                     score=0.0, h_t=h_t[i], c_t=c_t[i], last_alpha=None), True)}
+                                                      world_state=ws[0],
+                                                      observation=o[0],
+                                                      flat_index=None,
+                                                      last_action=-1,
+                                                      last_action_embedding=self.decoder.u_begin.view(-1),
+                                                      action_count=0,
+                                                      score=0.0, h_t=h_t[i], c_t=c_t[i], last_alpha=None), True)}
             for i, (ws, o) in enumerate(zip(world_states, initial_obs))
         ]
 
         beams = [[inf_state for world_state, (inf_state, expanded) in sorted(instance_cache.items())]
                  for instance_cache in state_cache] # sorting is a noop here since each instance_cache should only contain one entry
+
+        # traversed_lists[i]: inference states for instance i, in the order their
+        # viewpoints are physically visited as states are expanded
+        last_expanded_list = []
+        traversed_lists = []
+        for beam in beams:
+            assert len(beam) == 1
+            first_state = beam[0]
+            last_expanded_list.append(first_state)
+            traversed_lists.append([first_state])
+
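+        # update_traversed_lists extends traversed_lists[i] with a connected path
+        # visiting each newly-expanded state in order, backtracking through shared
+        # viewpoints via least_common_viewpoint_path when the search jumps between
+        # branches.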
+        def update_traversed_lists(new_visited_inf_states):
+            assert len(new_visited_inf_states) == len(last_expanded_list)
+            assert len(new_visited_inf_states) == len(traversed_lists)
+
+            for instance_index, instance_states in enumerate(new_visited_inf_states):
+                last_expanded = last_expanded_list[instance_index]
+                # todo: if this always passes, last_expanded_list is redundant with traversed_lists[i][-1]
+                assert last_expanded.world_state.viewpointId == traversed_lists[instance_index][-1].world_state.viewpointId
+                for inf_state in instance_states:
+                    path_from_last_to_next = least_common_viewpoint_path(last_expanded, inf_state)
+                    # path_from_last_to_next should include last_expanded's world state as its first element, so check and drop it
+                    assert path_from_last_to_next[0].world_state.viewpointId == last_expanded.world_state.viewpointId
+                    assert path_from_last_to_next[-1].world_state.viewpointId == inf_state.world_state.viewpointId
+                    traversed_lists[instance_index].extend(path_from_last_to_next[1:])
+                    last_expanded = inf_state
+                last_expanded_list[instance_index] = last_expanded
+
         # Do a sequence rollout and calculate the loss
         while any(len(comp) < completion_size for comp in completed):
             beam_indices = []
@@ -873,6 +936,7 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
             successor_obs = self.env.observe(world_states, beamed=True)
             beams = structured_map(lambda inf_state, obs: inf_state._replace(observation=obs), beams, successor_obs, nested=True)
+            update_traversed_lists(beams)
 
         completed_list = []
         for this_completed in completed:
@@ -884,6 +948,10 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
         completed_obs = self.env.observe(completed_ws, beamed=True)
         completed_list = structured_map(lambda inf_state, obs: inf_state._replace(observation=obs), completed_list, completed_obs, nested=True)
+        # TODO: consider moving observations and this update earlier so that we don't have to traverse as far back
+        update_traversed_lists(completed_list)
+
+        # TODO: sanity check the traversed lists here
 
         trajs = []
         for this_completed in completed_list:
@@ -907,7 +975,9 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
                     'attentions': path_attentions
                 })
             trajs.append(this_trajs)
-        return trajs
+        # completed_list: lists of final inference states corresponding to the candidates, one list per instance
+        # traversed_lists: the "physical states" that the robot has explored, one list per instance
+        return trajs, completed_list, traversed_lists
 
     def set_beam_size(self, beam_size):
         if self.env.beam_size < beam_size:
diff --git a/tasks/R2R/rational_follower.py b/tasks/R2R/rational_follower.py
index b2e724f..6e248e4 100644
--- a/tasks/R2R/rational_follower.py
+++ b/tasks/R2R/rational_follower.py
@@ -8,6 +8,8 @@
 import MatterSim
 import env
+from follower import least_common_viewpoint_path, path_element_from_observation
+
 import numpy as np
 from collections import namedtuple, Counter
@@ -18,7 +20,7 @@ def run_rational_follower(
         envir, evaluator, follower, speaker, beam_size, include_gold=False,
         output_file=None, eval_file=None, compute_oracle=False, mask_undo=False,
         state_factored_search=False,
-        state_first_n_ws_key=4):
+        state_first_n_ws_key=4, physical_traversal=False):
     follower.env = envir
     envir.reset_epoch()
@@ -46,9 +48,9 @@ def run_rational_follower(
         follower.feedback = feedback_method
         if state_factored_search:
-            beam_candidates = follower.state_factored_search(beam_size, 1, load_next_minibatch=not include_gold, mask_undo=mask_undo, first_n_ws_key=state_first_n_ws_key)
+            beam_candidates, candidate_inf_states, traversed_lists = follower.state_factored_search(beam_size, 1, load_next_minibatch=not include_gold, mask_undo=mask_undo, first_n_ws_key=state_first_n_ws_key)
         else:
-            beam_candidates = follower.beam_search(beam_size, load_next_minibatch=not include_gold, mask_undo=mask_undo)
+            beam_candidates, candidate_inf_states, traversed_lists = follower.beam_search(beam_size, load_next_minibatch=not include_gold, mask_undo=mask_undo)
 
         if include_gold:
             assert len(gold_candidates) == len(beam_candidates)
@@ -67,7 +69,7 @@ def run_rational_follower(
         speaker_scored_candidates, _ = speaker._score_obs_actions_and_instructions(cand_obs, cand_actions, cand_instr, feedback='teacher')
         assert len(speaker_scored_candidates) == sum(len(l) for l in beam_candidates)
         start_index = 0
-        for instance_candidates in beam_candidates:
+        for instance_index, instance_candidates in enumerate(beam_candidates):
             for i, candidate in enumerate(instance_candidates):
                 speaker_scored_candidate = speaker_scored_candidates[start_index + i]
                 assert candidate['instr_id'] == speaker_scored_candidate['instr_id']
@@ -75,6 +77,19 @@ def run_rational_follower(
                 candidate['speaker_score'] = speaker_scored_candidate['score']
                 # Delete keys that aren't needed for later processing
                 del candidate['observations']
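+                # With --physical_traversal, replace the candidate's beam trajectory
+                # with the path the agent would physically walk: the full traversal
+                # so far plus the leg from the last expanded state to this endpoint.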
+                if physical_traversal:
+                    last_traversed = traversed_lists[instance_index][-1]
+                    candidate_inf_state = candidate_inf_states[instance_index][i]
+                    path_from_last_to_next = least_common_viewpoint_path(last_traversed, candidate_inf_state)
+                    assert path_from_last_to_next[0].world_state.viewpointId == last_traversed.world_state.viewpointId
+                    assert path_from_last_to_next[-1].world_state.viewpointId == candidate_inf_state.world_state.viewpointId
+
+                    inf_traj = traversed_lists[instance_index] + path_from_last_to_next[1:]
+                    physical_trajectory = [path_element_from_observation(inf_state.observation)
+                                           for inf_state in inf_traj]
+                    # make sure the viewpointIds match
+                    assert physical_trajectory[-1][0] == candidate['trajectory'][-1][0]
+                    candidate['trajectory'] = physical_trajectory
                 if compute_oracle:
                     candidate['eval_result'] = evaluator._score_item(candidate['instr_id'], candidate['trajectory'])._asdict()
             start_index += len(instance_candidates)
@@ -182,7 +197,9 @@ def validate_entry_point(args):
         eval_file=eval_file, compute_oracle=args.compute_oracle,
         mask_undo=args.mask_undo,
         state_factored_search=args.state_factored_search,
-        state_first_n_ws_key=args.state_first_n_ws_key)
+        state_first_n_ws_key=args.state_first_n_ws_key,
+        physical_traversal=args.physical_traversal,
+    )
     pprint.pprint(accuracies_by_weight)
     pprint.pprint({w:sorted(d.items()) for w, d in index_counts_by_weight.items()})
     weight, score_summary = max(accuracies_by_weight.items(), key=lambda pair: pair[1]['success_rate'])
@@ -203,6 +220,7 @@ def make_arg_parser():
     parser.add_argument("--mask_undo", action='store_true')
     parser.add_argument("--state_factored_search", action='store_true')
     parser.add_argument("--state_first_n_ws_key", type=int, default=4)
+    parser.add_argument("--physical_traversal", action='store_true')
     return parser
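
Note: --physical_traversal currently requires --state_factored_search,
since beam_search still returns traversed_lists = None (see the todo
above) and rational_follower would fail indexing into it. A hypothetical
invocation (model snapshots and the other required arguments, which are
not shown in this patch, elided as "..."):

    python tasks/R2R/rational_follower.py ... --state_factored_search --physical_traversal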