Skip to content

Commit

Permalink
merge rational speaker into panorama
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed May 10, 2018
1 parent b681a2a commit 4f4207d
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 39 deletions.
39 changes: 28 additions & 11 deletions tasks/R2R/rational_speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

SpeakerCandidate = namedtuple("SpeakerCandidate", "instr_id, observations, actions, instr_encoding, follower_score, speaker_score")

def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include_gold=False, output_file=None):
def generate_and_score_candidates(envir, speaker, follower, n_candidates, include_gold=False):
follower.env = envir
speaker.env = envir
envir.reset_epoch()
Expand All @@ -31,6 +31,7 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include

candidate_lists_by_instr_id = {}

num_candidates_per_instance = []
looped = False
batch_idx = 0
while True:
Expand All @@ -42,7 +43,7 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include
else:
gold_candidates = []

beam_candidates = speaker.beam_search(beam_size, path_obs, path_actions)
beam_candidates = speaker.beam_search(n_candidates, path_obs, path_actions)

if include_gold:
assert len(gold_candidates) == len(beam_candidates)
Expand All @@ -55,6 +56,7 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include
cand_obs = []
cand_actions = []
for beam_index, this_beam in enumerate(beam_candidates):
num_candidates_per_instance.append(len(this_beam))
for candidate in this_beam:
cand_obs.append(path_obs[beam_index])
cand_actions.append(path_actions[beam_index])
Expand Down Expand Up @@ -85,6 +87,12 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include
if looped:
break

print("average distinct candidates per instance: {}".format(np.mean(num_candidates_per_instance)))

return candidate_lists_by_instr_id


def predict_from_candidates(candidate_lists_by_instr_id, speaker_weights):
speaker_scores = [cand['speaker_score']
for lst in candidate_lists_by_instr_id.values()
for cand in lst]
Expand All @@ -96,7 +104,6 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include
follower_std = np.std(follower_scores)

results_by_weight = {}
index_counts_by_weight = {}

for speaker_weight in np.arange(0, 20 + 1) / 20.0:
results = {}
Expand All @@ -110,10 +117,21 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include
results[instr_id] = best_cand
index_count[best_ix] += 1

score_summary, _ = evaluator.score_results(results)
results_by_weight[speaker_weight] = results

return results_by_weight


def run_rational_speaker(envir, speaker_evaluator, speaker, follower, n_candidates, include_gold=False, output_file=None):
candidate_lists_by_instr_id = generate_and_score_candidates(envir, speaker, follower, n_candidates)

speaker_weights = np.arange(0, 20 + 1) / 20.0
results_by_weight = predict_from_candidates(candidate_lists_by_instr_id, speaker_weights)

results_by_weight[speaker_weight] = score_summary
index_counts_by_weight[speaker_weight] = index_count
scores_by_weight = {}
for speaker_weight, results in results_by_weight.items():
score_summary, _ = speaker_evaluator.score_results(results)
scores_by_weight[speaker_weight] = score_summary

if output_file:
with open(output_file, 'w') as f:
Expand All @@ -125,7 +143,7 @@ def run_rational_speaker(envir, evaluator, speaker, follower, beam_size, include
candidate['gold'] = (include_gold and i == 0)
utils.pretty_json_dump(candidate_lists_by_instr_id, f)

return results_by_weight, index_counts_by_weight
return scores_by_weight, results_by_weight

def validate_entry_point(args):
follower, follower_train_env, follower_val_envs = train.train_setup(args, args.batch_size)
Expand All @@ -142,7 +160,7 @@ def validate_entry_point(args):
output_file = "{}_{}.json".format(args.output_file, env_name)
else:
output_file = None
results_by_weight, index_counts_by_weight = run_rational_speaker(
scores_by_weight, _ = run_rational_speaker(
env,
evaluator,
speaker,
Expand All @@ -151,9 +169,8 @@ def validate_entry_point(args):
include_gold=args.include_gold,
output_file=output_file
)
pprint.pprint(results_by_weight)
pprint.pprint({w:sorted(d.items()) for w, d in index_counts_by_weight.items()})
weight, score_summary = max(results_by_weight.items(), key=lambda pair: pair[1]['bleu'])
pprint.pprint(scores_by_weight)
weight, score_summary = max(scores_by_weight.items(), key=lambda pair: pair[1]['bleu'])
print("max success_rate with weight: {}".format(weight))
for metric,val in score_summary.items():
print("{} {}\t{}".format(env_name, metric, val))
Expand Down
96 changes: 68 additions & 28 deletions tasks/R2R/self_play_from_speaker.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,94 @@
import argparse
import utils
import train
import train_speaker
from speaker import Seq2SeqSpeaker
from follower import Seq2SeqAgent
import rational_speaker

def selfplay_setup(args):
def selfplay_speaker_setup(args):
train_splits = ['train']
pred_splits = args.pred_splits
vocab = train_speaker.TRAIN_VOCAB

train_env, pred_envs, encoder, decoder = train_speaker.make_env_and_models(args, vocab, train_splits, pred_splits, test_instruction_limit=1)
train_env, pred_envs, encoder, decoder = train_speaker.make_env_and_models(
args, vocab, train_splits, pred_splits, test_instruction_limit=1)
agent = Seq2SeqSpeaker(train_env, "", encoder, decoder, train_speaker.MAX_INSTRUCTION_LENGTH)
return agent, train_env, pred_envs


def selfplay_follower_setup(args):
train_splits = ['train']
pred_splits = args.pred_splits
vocab = train.TRAIN_VOCAB

train_env, pred_envs, encoder, decoder = train.make_env_and_models(
args, vocab, train_splits, pred_splits, batch_size=args.batch_size
)
agent = Seq2SeqAgent(train_env, "", encoder, decoder, train.max_episode_len, max_instruction_length=train.MAX_INPUT_LENGTH)
return agent, train_env, pred_envs


def entry_point(args):
agent, train_env, val_envs = selfplay_setup(args)
agent.load(args.model_prefix)
speaker, train_env, val_envs = selfplay_speaker_setup(args)
speaker.load(args.speaker_model_prefix)

assert (args.rational_speaker_weights is None) == (args.follower_model_prefix is None),\
"must pass both --rational_speaker_weight and --follower_model_prefix, or neither"

pragmatic_speaker = (args.follower_model_prefix is not None)

if pragmatic_speaker:
follower, train_env_follower, val_envs_followers = selfplay_follower_setup(args)
follower.load(args.follower_model_prefix)
else:
follower = None

for env_name, (env, evaluator) in val_envs.items():
agent.env = env
agent.env.print_progress = True

## gold
# gold_results = agent.test(use_dropout=False, feedback='teacher', allow_cheat=True)
# gold_score_summary = evaluator.score_results(gold_results, verbose=False)
#
# for metric,val in gold_score_summary.items():
# print("gold {} {}\t{}".format(env_name, metric, val))
#
# if args.gold_results_output_file:
# fname = "{}_{}.json".format(args.gold_results_output_file, env_name)
# with open(fname, 'w') as f:
# utils.pretty_json_dump(gold_results, f)
speaker.env = env
speaker.env.print_progress = True

## predicted
pred_results = agent.test(use_dropout=False, feedback='argmax')
pred_score_summary, pred_replaced_gt = evaluator.score_results(pred_results, verbose=False)
if pragmatic_speaker:
candidate_lists_by_instr_id = rational_speaker.generate_and_score_candidates(
env, speaker, follower, args.rational_speaker_n_candidates,
include_gold=False
)
results_by_weight = rational_speaker.predict_from_candidates(candidate_lists_by_instr_id, args.rational_speaker_weights)
results_by_name = {
'rational_speaker_{}'.format(speaker_weight): results
for speaker_weight, results in results_by_weight.items()
}
else:
pred_results = speaker.test(use_dropout=False, feedback='argmax')
results_by_name = {'literal_speaker' : pred_results}

for metric,val in pred_score_summary.items():
print("pred {} {}\t{}".format(env_name, metric, val))
for name, pred_results in results_by_name.items():
pred_score_summary, pred_replaced_gt = evaluator.score_results(pred_results, verbose=False)

for metric,val in pred_score_summary.items():
print("pred {} {} {}\t{}".format(name, env_name, metric, val))

fname = "{}_{}_{}.json".format(args.pred_results_output_file, name, env_name)
with open(fname, 'w') as f:
utils.pretty_json_dump(pred_replaced_gt, f)

fname = "{}_{}.json".format(args.pred_results_output_file, env_name)
with open(fname, 'w') as f:
utils.pretty_json_dump(pred_replaced_gt, f)

def make_arg_parser():
parser = train_speaker.make_arg_parser()
parser.add_argument("model_prefix")
#parser = train_speaker.make_arg_parser()
# todo: hack, this only works because the follower has extra parameters that the speaker lacks!
parser = train.make_arg_parser()
parser.add_argument("speaker_model_prefix")
parser.add_argument("pred_results_output_file")
parser.add_argument("--batch_size", type=int, default=20)
parser.add_argument("--pred_splits", nargs="+", default=["train_selfplay"])
#parser.add_argument("--beam_size", type=int, default=1)

# for rational self-play generation
parser.add_argument("--follower_model_prefix",
help="generate data from a rational speaker (must also pass --rational_speaker_weights")
parser.add_argument("--rational_speaker_weights", type=float, nargs="+",
help="list of speaker weights in range [0.0, 1.0] to use with rational speaker (must also pass follower_model_prefix)")
parser.add_argument("--rational_speaker_n_candidates", type=int, default=40)
return parser

if __name__ == "__main__":
Expand Down

0 comments on commit 4f4207d

Please sign in to comment.