diff --git a/tasks/R2R/train.py b/tasks/R2R/train.py index 17bdac7..a7e2db0 100644 --- a/tasks/R2R/train.py +++ b/tasks/R2R/train.py @@ -221,14 +221,15 @@ def train_setup(args, batch_size=BATCH_SIZE): return agent, train_env, val_envs -def test_setup(args, batch_size=BATCH_SIZE): - train_env, test_envs, encoder, decoder = make_env_and_models( - args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test'], - batch_size=batch_size) - agent = Seq2SeqAgent( - None, "", encoder, decoder, max_episode_len, - max_instruction_length=MAX_INPUT_LENGTH) - return agent, train_env, test_envs +# Test set prediction will be handled separately +# def test_setup(args, batch_size=BATCH_SIZE): +# train_env, test_envs, encoder, decoder = make_env_and_models( +# args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test'], +# batch_size=batch_size) +# agent = Seq2SeqAgent( +# None, "", encoder, decoder, max_episode_len, +# max_instruction_length=MAX_INPUT_LENGTH) +# return agent, train_env, test_envs def train_val(args): @@ -237,21 +238,22 @@ def train_val(args): train(args, train_env, agent, val_envs=val_envs) -def test_submission(args): - ''' Train on combined training and validation sets, and generate test - submission. ''' - agent, train_env, test_envs = test_setup(args) - train(args, train_env, agent) - - test_env = test_envs['test'] - agent.env = test_env - - agent.results_path = '%s%s_%s_iter_%d.json' % ( - RESULT_DIR, get_model_prefix(args, train_env.image_features_list), - 'test', args.n_iters) - agent.test(use_dropout=False, feedback='argmax') - if not args.no_save: - agent.write_results() +# Test set prediction will be handled separately +# def test_submission(args): +# ''' Train on combined training and validation sets, and generate test +# submission. ''' +# agent, train_env, test_envs = test_setup(args) +# train(args, train_env, agent) +# +# test_env = test_envs['test'] +# agent.env = test_env +# +# agent.results_path = '%s%s_%s_iter_%d.json' % ( +# RESULT_DIR, get_model_prefix(args, train_env.image_features_list), +# 'test', args.n_iters) +# agent.test(use_dropout=False, feedback='argmax') +# if not args.no_save: +# agent.write_results() def make_arg_parser(): diff --git a/tasks/R2R/train_speaker.py b/tasks/R2R/train_speaker.py index f9724f7..ec17ddf 100644 --- a/tasks/R2R/train_speaker.py +++ b/tasks/R2R/train_speaker.py @@ -220,13 +220,14 @@ def train_setup(args): return agent, train_env, val_envs -def test_setup(args): - train_env, test_envs, encoder, decoder = make_env_and_models( - args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test']) - agent = Seq2SeqSpeaker( - None, "", encoder, decoder, MAX_INSTRUCTION_LENGTH, - max_episode_len=max_episode_len) - return agent, train_env, test_envs +# Test set prediction will be handled separately +# def test_setup(args): +# train_env, test_envs, encoder, decoder = make_env_and_models( +# args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test']) +# agent = Seq2SeqSpeaker( +# None, "", encoder, decoder, MAX_INSTRUCTION_LENGTH, +# max_episode_len=max_episode_len) +# return agent, train_env, test_envs def train_val(args): @@ -235,21 +236,22 @@ def train_val(args): train(args, train_env, agent, val_envs=val_envs) -def test_submission(args): - ''' Train on combined training and validation sets, and generate test - submission. ''' - agent, train_env, test_envs = test_setup(args) - train(args, train_env, agent) - - test_env = test_envs['test'] - agent.env = test_env - - agent.results_path = '%s%s_%s_iter_%d.json' % ( - args.result_dir, get_model_prefix(args, train_env.image_features_list), - 'test', n_iters) - agent.test(use_dropout=False, feedback='argmax') - if not args.no_save: - agent.write_results() +# Test set prediction will be handled separately +# def test_submission(args): +# ''' Train on combined training and validation sets, and generate test +# submission. ''' +# agent, train_env, test_envs = test_setup(args) +# train(args, train_env, agent) +# +# test_env = test_envs['test'] +# agent.env = test_env +# +# agent.results_path = '%s%s_%s_iter_%d.json' % ( +# args.result_dir, get_model_prefix(args, train_env.image_features_list), +# 'test', n_iters) +# agent.test(use_dropout=False, feedback='argmax') +# if not args.no_save: +# agent.write_results() def make_arg_parser():