Skip to content

Commit

Permalink
add scripts to output evaluation files for greedy follower
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed May 21, 2018
1 parent 1d18d71 commit 8beac87
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tasks/R2R/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ def validate_entry_point(args):
results = agent.test(use_dropout=False, feedback='argmax', beam_size=args.beam_size)
score_summary, _ = evaluator.score_results(results)

if args.eval_file:
eval_file = "{}_{}.json".format(args.eval_file, env_name)
eval_results = []
for instr_id, result in results.items():
eval_results.append(
{'instr_id': instr_id, 'trajectory': result['trajectory']})
with open(eval_file, 'w') as f:
utils.pretty_json_dump(eval_results, f)

# TODO: testing code, remove
# score_summary_direct, _ = evaluator.score_results(agent.results)
# assert score_summary == score_summary_direct
Expand All @@ -29,6 +38,7 @@ def make_arg_parser():
parser.add_argument("model_prefix")
parser.add_argument("--beam_size", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--eval_file")
return parser

if __name__ == "__main__":
Expand Down

0 comments on commit 8beac87

Please sign in to comment.