Skip to content

Commit

Permalink
comment out test set submission procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed Sep 17, 2018
1 parent 4773a71 commit 36d602e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 45 deletions.
48 changes: 25 additions & 23 deletions tasks/R2R/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,15 @@ def train_setup(args, batch_size=BATCH_SIZE):
return agent, train_env, val_envs


def test_setup(args, batch_size=BATCH_SIZE):
train_env, test_envs, encoder, decoder = make_env_and_models(
args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test'],
batch_size=batch_size)
agent = Seq2SeqAgent(
None, "", encoder, decoder, max_episode_len,
max_instruction_length=MAX_INPUT_LENGTH)
return agent, train_env, test_envs
# Test set prediction will be handled separately
# def test_setup(args, batch_size=BATCH_SIZE):
# train_env, test_envs, encoder, decoder = make_env_and_models(
# args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test'],
# batch_size=batch_size)
# agent = Seq2SeqAgent(
# None, "", encoder, decoder, max_episode_len,
# max_instruction_length=MAX_INPUT_LENGTH)
# return agent, train_env, test_envs


def train_val(args):
Expand All @@ -237,21 +238,22 @@ def train_val(args):
train(args, train_env, agent, val_envs=val_envs)


def test_submission(args):
''' Train on combined training and validation sets, and generate test
submission. '''
agent, train_env, test_envs = test_setup(args)
train(args, train_env, agent)

test_env = test_envs['test']
agent.env = test_env

agent.results_path = '%s%s_%s_iter_%d.json' % (
RESULT_DIR, get_model_prefix(args, train_env.image_features_list),
'test', args.n_iters)
agent.test(use_dropout=False, feedback='argmax')
if not args.no_save:
agent.write_results()
# Test set prediction will be handled separately
# def test_submission(args):
# ''' Train on combined training and validation sets, and generate test
# submission. '''
# agent, train_env, test_envs = test_setup(args)
# train(args, train_env, agent)
#
# test_env = test_envs['test']
# agent.env = test_env
#
# agent.results_path = '%s%s_%s_iter_%d.json' % (
# RESULT_DIR, get_model_prefix(args, train_env.image_features_list),
# 'test', args.n_iters)
# agent.test(use_dropout=False, feedback='argmax')
# if not args.no_save:
# agent.write_results()


def make_arg_parser():
Expand Down
46 changes: 24 additions & 22 deletions tasks/R2R/train_speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,14 @@ def train_setup(args):
return agent, train_env, val_envs


def test_setup(args):
train_env, test_envs, encoder, decoder = make_env_and_models(
args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test'])
agent = Seq2SeqSpeaker(
None, "", encoder, decoder, MAX_INSTRUCTION_LENGTH,
max_episode_len=max_episode_len)
return agent, train_env, test_envs
# Test set prediction will be handled separately
# def test_setup(args):
# train_env, test_envs, encoder, decoder = make_env_and_models(
# args, TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen'], ['test'])
# agent = Seq2SeqSpeaker(
# None, "", encoder, decoder, MAX_INSTRUCTION_LENGTH,
# max_episode_len=max_episode_len)
# return agent, train_env, test_envs


def train_val(args):
Expand All @@ -235,21 +236,22 @@ def train_val(args):
train(args, train_env, agent, val_envs=val_envs)


def test_submission(args):
''' Train on combined training and validation sets, and generate test
submission. '''
agent, train_env, test_envs = test_setup(args)
train(args, train_env, agent)

test_env = test_envs['test']
agent.env = test_env

agent.results_path = '%s%s_%s_iter_%d.json' % (
args.result_dir, get_model_prefix(args, train_env.image_features_list),
'test', n_iters)
agent.test(use_dropout=False, feedback='argmax')
if not args.no_save:
agent.write_results()
# Test set prediction will be handled separately
# def test_submission(args):
# ''' Train on combined training and validation sets, and generate test
# submission. '''
# agent, train_env, test_envs = test_setup(args)
# train(args, train_env, agent)
#
# test_env = test_envs['test']
# agent.env = test_env
#
# agent.results_path = '%s%s_%s_iter_%d.json' % (
# args.result_dir, get_model_prefix(args, train_env.image_features_list),
# 'test', n_iters)
# agent.test(use_dropout=False, feedback='argmax')
# if not args.no_save:
# agent.write_results()


def make_arg_parser():
Expand Down

0 comments on commit 36d602e

Please sign in to comment.