diff --git a/tasks/R2R/validate.py b/tasks/R2R/validate.py index 2cb5e34..fca0129 100644 --- a/tasks/R2R/validate.py +++ b/tasks/R2R/validate.py @@ -17,6 +17,15 @@ def validate_entry_point(args): results = agent.test(use_dropout=False, feedback='argmax', beam_size=args.beam_size) score_summary, _ = evaluator.score_results(results) + if args.eval_file: + eval_file = "{}_{}.json".format(args.eval_file, env_name) + eval_results = [] + for instr_id, result in results.items(): + eval_results.append( + {'instr_id': instr_id, 'trajectory': result['trajectory']}) + with open(eval_file, 'w') as f: + utils.pretty_json_dump(eval_results, f) + # TODO: testing code, remove # score_summary_direct, _ = evaluator.score_results(agent.results) # assert score_summary == score_summary_direct @@ -29,6 +38,7 @@ def make_arg_parser(): parser.add_argument("model_prefix") parser.add_argument("--beam_size", type=int, default=1) parser.add_argument("--batch_size", type=int, default=100) + parser.add_argument("--eval_file") return parser if __name__ == "__main__":