Skip to content

Commit

Permalink
add attention exporting in rational follower
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed May 15, 2018
1 parent 4f4207d commit 684c54f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
23 changes: 15 additions & 8 deletions tasks/R2R/follower.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#from env import FOLLOWER_MODEL_ACTIONS, FOLLOWER_ENV_ACTIONS, IGNORE_ACTION_INDEX, LEFT_ACTION_INDEX, RIGHT_ACTION_INDEX, START_ACTION_INDEX, END_ACTION_INDEX, FORWARD_ACTION_INDEX, index_action_tuple

# One node in the follower's beam-search back-chain. `prev_inference_state`
# links to the parent node (None for the start state); `last_alpha` holds the
# attention weights produced when `last_action` was chosen (None at the start).
InferenceState = namedtuple("InferenceState", "prev_inference_state, world_state, observation, flat_index, last_action, last_action_embedding, action_count, score, h_t, c_t, last_alpha")

def backchain_inference_states(last_inference_state):
    """Walk the prev_inference_state chain back to the start state and return
    the trajectory in forward order.

    Returns a 5-tuple of lists:
        states       -- world states, including the start state
        observations -- observations, including the start observation
        actions      -- actions taken (start pseudo-action excluded)
        scores       -- per-step score deltas (start excluded)
        attentions   -- per-step attention weights (start excluded)
    """
    states = []
    observations = []
    actions = []
    inf_state = last_inference_state
    scores = []
    last_score = None
    attentions = []
    while inf_state is not None:
        states.append(inf_state.world_state)
        observations.append(inf_state.observation)
        actions.append(inf_state.last_action)
        attentions.append(inf_state.last_alpha)
        # score is cumulative; the difference against the child node's score
        # recovers the per-step delta.
        if last_score is not None:
            scores.append(last_score - inf_state.score)
        last_score = inf_state.score
        inf_state = inf_state.prev_inference_state
    scores.append(last_score)
    return list(reversed(states)), list(reversed(observations)), list(reversed(actions))[1:], list(reversed(scores))[1:], list(reversed(attentions))[1:] # exclude start action

def batch_instructions_from_encoded(encoded_instructions, max_length, reverse=False, sort=False):
# encoded_instructions: list of lists of token indices (should not be padded, or contain BOS or EOS tokens)
Expand Down Expand Up @@ -526,7 +528,7 @@ def beam_search(self, beam_size, load_next_minibatch=True, mask_undo=False):
last_action=-1,
last_action_embedding=self.decoder.u_begin.view(-1),
action_count=0,
score=0.0, h_t=None, c_t=None)]
score=0.0, h_t=None, c_t=None, last_alpha=None)]
for i, (ws, o) in enumerate(zip(world_states, obs))
]

Expand Down Expand Up @@ -597,7 +599,8 @@ def beam_search(self, beam_size, load_next_minibatch=True, mask_undo=False):
last_action=action_index,
last_action_embedding=all_u_t[flat_index, action_index].detach(),
action_count=inf_state.action_count + 1,
score=float(inf_state.score + action_score), h_t=None, c_t=None)
score=float(inf_state.score + action_score), h_t=None, c_t=None,
last_alpha=alpha[flat_index].data)
)
start_index = end_index
successors = sorted(successors, key=lambda t: t.score, reverse=True)[:beam_size]
Expand Down Expand Up @@ -660,7 +663,7 @@ def beam_search(self, beam_size, load_next_minibatch=True, mask_undo=False):
assert this_completed
this_trajs = []
for inf_state in sorted(this_completed, key=lambda t: t.score, reverse=True)[:beam_size]:
path_states, path_observations, path_actions, path_scores = backchain_inference_states(inf_state)
path_states, path_observations, path_actions, path_scores, path_attentions = backchain_inference_states(inf_state)
# this will have messed-up headings for (at least some) starting locations because of
# discretization, so read from the observations instead
## path = [(obs.viewpointId, state.heading, state.elevation)
Expand All @@ -674,6 +677,7 @@ def beam_search(self, beam_size, load_next_minibatch=True, mask_undo=False):
'actions': path_actions,
'score': inf_state.score,
'scores': path_scores,
'attentions': path_attentions
})
trajs.append(this_trajs)
return trajs
Expand Down Expand Up @@ -704,7 +708,7 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
last_action=-1,
last_action_embedding=self.decoder.u_begin.view(-1),
action_count=0,
score=0.0, h_t=h_t[i], c_t=c_t[i]), True)}
score=0.0, h_t=h_t[i], c_t=c_t[i], last_alpha=None), True)}
for i, (ws, o) in enumerate(zip(world_states, initial_obs))
]

Expand Down Expand Up @@ -781,7 +785,9 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
last_action=action_index,
last_action_embedding=all_u_t[flat_index, action_index].detach(),
action_count=inf_state.action_count + 1,
score=inf_state.score + action_score, h_t=h_t[flat_index], c_t=c_t[flat_index])
score=inf_state.score + action_score,
h_t=h_t[flat_index], c_t=c_t[flat_index],
last_alpha=alpha[flat_index].data)
)
start_index = end_index
successors = sorted(successors, key=lambda t: t.score, reverse=True)
Expand Down Expand Up @@ -884,7 +890,7 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
assert this_completed
this_trajs = []
for inf_state in this_completed:
path_states, path_observations, path_actions, path_scores = backchain_inference_states(inf_state)
path_states, path_observations, path_actions, path_scores, path_attentions = backchain_inference_states(inf_state)
# this will have messed-up headings for (at least some) starting locations because of
# discretization, so read from the observations instead
## path = [(obs.viewpointId, state.heading, state.elevation)
Expand All @@ -898,6 +904,7 @@ def state_factored_search(self, completion_size, successor_size, load_next_minib
'actions': path_actions,
'score': inf_state.score,
'scores': path_scores,
'attentions': path_attentions
})
trajs.append(this_trajs)
return trajs
Expand Down
2 changes: 2 additions & 0 deletions tasks/R2R/rational_follower.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def run_rational_follower(
candidate['actions'] = candidate['actions']
candidate['scored_actions'] = list(zip(candidate['actions'], candidate['scores']))
candidate['instruction'] = envir.tokenizer.decode_sentence(candidate['instr_encoding'], break_on_eos=False, join=True)
if 'attentions' in candidate:
candidate['attentions'] = [list(tens) for tens in candidate['attentions']]
del candidate['instr_encoding']
del candidate['trajectory']
candidate['rank'] = i
Expand Down
2 changes: 2 additions & 0 deletions tasks/R2R/rational_speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def generate_and_score_candidates(envir, speaker, follower, n_candidates, includ
candidate['speaker_score'] = candidate['score']
candidate['follower_score'] = follower_scored_candidate['score']
candidate['actions'] = follower_scored_candidate['actions']
if 'attentions' in candidate:
candidate['attentions'] = [list(tens) for tens in candidate['attentions']]
assert np.allclose(np.sum(follower_scored_candidate['scores']), follower_scored_candidate['score'])
start_index += len(instance_candidates)
assert utils.all_equal([i['instr_id'] for i in instance_candidates])
Expand Down
17 changes: 11 additions & 6 deletions tasks/R2R/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@
from utils import vocab_pad_idx, vocab_bos_idx, vocab_eos_idx, flatten, try_cuda
from follower import batch_instructions_from_encoded

# One node in the speaker's beam-search back-chain. `prev_inference_state`
# links to the parent node (None for the BOS state); `last_alpha` holds the
# attention weights produced when `last_word` was emitted (None at BOS).
InferenceState = namedtuple("InferenceState", "prev_inference_state, flat_index, last_word, word_count, score, last_alpha")

def backchain_inference_states(last_inference_state):
    """Walk the prev_inference_state chain back to the BOS state and return
    the decoded sequence in forward order.

    Returns a 3-tuple of lists (BOS pseudo-word excluded from each):
        word_indices -- emitted vocabulary indices
        scores       -- per-word score deltas
        attentions   -- per-word attention weights
    """
    word_indices = []
    inf_state = last_inference_state
    scores = []
    last_score = None
    attentions = []
    while inf_state is not None:
        word_indices.append(inf_state.last_word)
        attentions.append(inf_state.last_alpha)
        # score is cumulative; the difference against the child node's score
        # recovers the per-word delta.
        if last_score is not None:
            scores.append(last_score - inf_state.score)
        last_score = inf_state.score
        inf_state = inf_state.prev_inference_state
    scores.append(last_score)
    return list(reversed(word_indices))[1:], list(reversed(scores))[1:], list(reversed(attentions))[1:] # exclude BOS

class Seq2SeqSpeaker(object):
feedback_options = ['teacher', 'argmax', 'sample']
Expand Down Expand Up @@ -228,7 +230,8 @@ def beam_search(self, beam_size, path_obs, path_actions):
flat_index=i,
last_word=vocab_bos_idx,
word_count=0,
score=0.0)]
score=0.0,
last_alpha=None)]
for i in range(batch_size)
]

Expand Down Expand Up @@ -269,7 +272,8 @@ def beam_search(self, beam_size, path_obs, path_actions):
flat_index=flat_index,
last_word=word_index,
word_count=inf_state.word_count + 1,
score=inf_state.score + word_score)
score=inf_state.score + word_score,
last_alpha=alpha[flat_index].data)
)
start_index = end_index
successors = sorted(successors, key=lambda t: t.score, reverse=True)[:beam_size]
Expand Down Expand Up @@ -302,13 +306,14 @@ def beam_search(self, beam_size, path_obs, path_actions):
this_completed = completed[perm_index]
instr_id = start_obs[perm_index]['instr_id']
for inf_state in sorted(this_completed, key=lambda t: t.score, reverse=True)[:beam_size]:
word_indices, scores = backchain_inference_states(inf_state)
word_indices, scores, attentions = backchain_inference_states(inf_state)
this_outputs.append({
'instr_id': instr_id,
'word_indices': word_indices,
'score': inf_state.score,
'scores': scores,
'words': self.env.tokenizer.decode_sentence(word_indices, break_on_eos=True, join=False)
'words': self.env.tokenizer.decode_sentence(word_indices, break_on_eos=True, join=False),
'attentions': attentions,
})
return outputs

Expand Down

0 comments on commit 684c54f

Please sign in to comment.