Commit

implement selfplay training
ronghanghu committed Sep 17, 2018
1 parent 36d602e · commit 1506b29
Showing 2 changed files with 52 additions and 16 deletions.
3 changes: 3 additions & 0 deletions tasks/R2R/data/download_precomputed_augmentation.sh
@@ -0,0 +1,3 @@
+#!/bin/sh
+
+wget https://people.eecs.berkeley.edu/~ronghang/projects/speaker_follower/data_augmentation/R2R_literal_speaker_data_augmentation_paths.json -P tasks/R2R/data/
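
Once fetched, the file can be sanity-checked like any other R2R data split. A minimal sketch, assuming the download landed at the path below and that the file is a JSON list of path items in the standard R2R split format:

import json

# Assumed location: where the script above tells wget to place the file.
path = 'tasks/R2R/data/R2R_literal_speaker_data_augmentation_paths.json'
with open(path) as f:
    augmented_paths = json.load(f)
print('%d speaker-generated augmentation items' % len(augmented_paths))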
65 changes: 49 additions & 16 deletions tasks/R2R/train.py
@@ -49,6 +49,9 @@ def get_model_prefix(args, image_feature_list):
         model_prefix = 'trainsub_' + model_prefix
     if args.bidirectional:
         model_prefix = model_prefix + "_bidirectional"
+    if args.use_pretraining:
+        model_prefix = model_prefix.replace(
+            'follower', 'follower_with_pretraining', 1)
     return model_prefix
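
The new branch is a one-shot substring substitution on the model prefix used for checkpoint and log names; the trailing 1 limits str.replace to the first occurrence of 'follower'. A tiny illustration with a hypothetical prefix:

prefix = 'follower_sample_imagenet'  # hypothetical prefix value
print(prefix.replace('follower', 'follower_with_pretraining', 1))
# follower_with_pretraining_sample_imagenet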


@@ -62,19 +65,14 @@ def filter_param(param_list):
     return [p for p in param_list if p.requires_grad]
 
 
-def train(args, train_env, agent, log_every=log_every, val_envs=None):
+def train(args, train_env, agent, encoder_optimizer, decoder_optimizer,
+          n_iters, log_every=log_every, val_envs=None):
     ''' Train on training set, validating on both seen and unseen. '''
 
     if val_envs is None:
         val_envs = {}
 
     print('Training with %s feedback' % args.feedback_method)
-    encoder_optimizer = optim.Adam(
-        filter_param(agent.encoder.parameters()), lr=learning_rate,
-        weight_decay=weight_decay)
-    decoder_optimizer = optim.Adam(
-        filter_param(agent.decoder.parameters()), lr=learning_rate,
-        weight_decay=weight_decay)
 
     data_log = defaultdict(list)
     start = time.time()
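
Hoisting the Adam construction out of train() lets the caller reuse the same optimizers across several train() calls, so the per-parameter state (step counts, moment estimates) carries over from pretraining into the main stage instead of being reset. A self-contained sketch of the pattern, with a toy model and random data rather than anything from this repo:

import torch
import torch.optim as optim

model = torch.nn.Linear(4, 1)
# One optimizer shared by both stages keeps its internal Adam state.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def run_stage(n_steps):
    for _ in range(n_steps):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

run_stage(100)  # stand-in for pretraining on augmented data
run_stage(100)  # stand-in for training on the original splits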
@@ -89,10 +87,10 @@ def make_path(n_iter):

     best_metrics = {}
     last_model_saved = {}
-    for idx in range(0, args.n_iters, log_every):
+    for idx in range(0, n_iters, log_every):
         agent.env = train_env
 
-        interval = min(log_every, args.n_iters-idx)
+        interval = min(log_every, n_iters-idx)
         iter = idx + interval
         data_log['iteration'].append(iter)
@@ -151,8 +149,8 @@ def make_path(n_iter):
                 last_model_saved[key] = model_path
 
         print(('%s (%d %d%%) %s' % (
-            timeSince(start, float(iter)/args.n_iters),
-            iter, float(iter)/args.n_iters*100, loss_str)))
+            timeSince(start, float(iter)/n_iters),
+            iter, float(iter)/n_iters*100, loss_str)))
         for s in save_log:
             print(s)

@@ -162,9 +160,9 @@ def make_path(n_iter):

     df = pd.DataFrame(data_log)
     df.set_index('iteration')
-    df_path = '%s%s_log.csv' % (
+    df_path = '%s%s_%s_log.csv' % (
         PLOT_DIR, get_model_prefix(
-            args, train_env.image_features_list))
+            args, train_env.image_features_list), split_string)
     df.to_csv(df_path)
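
The extra split_string component in the CSV name (split_string itself is assigned in a part of the diff not expanded here) presumably keeps the log files of the two stages apart, since the pretraining and main calls to train() now run in the same process and would otherwise write to the same path.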


@@ -173,6 +171,17 @@ def setup():
     torch.cuda.manual_seed(1)
 
 
+def make_more_train_env(args, train_vocab_path, train_splits,
+                        batch_size=BATCH_SIZE):
+    setup()
+    image_features_list = ImageFeatures.from_args(args)
+    vocab = read_vocab(train_vocab_path)
+    tok = Tokenizer(vocab=vocab)
+    train_env = R2RBatch(image_features_list, batch_size=batch_size,
+                         splits=train_splits, tokenizer=tok)
+    return train_env
+
+
 def make_env_and_models(args, train_vocab_path, train_splits, test_splits,
                         batch_size=BATCH_SIZE):
     setup()
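
make_more_train_env rebuilds only the environment half of make_env_and_models, so a second R2RBatch over extra splits can be created without constructing new encoder/decoder models. A hypothetical call, assuming the vocab file sits at tasks/R2R/data/train_vocab.txt and that R2RBatch resolves a split name s to tasks/R2R/data/R2R_s.json (which is how the downloaded augmentation file would be picked up):

# Hypothetical usage; presumes a parsed args object as elsewhere in this file.
pretrain_env = make_more_train_env(
    args, 'tasks/R2R/data/train_vocab.txt',
    ['literal_speaker_data_augmentation_paths'])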
@@ -215,10 +224,19 @@ def train_setup(args, batch_size=BATCH_SIZE):

     train_env, val_envs, encoder, decoder = make_env_and_models(
         args, vocab, train_splits, val_splits, batch_size=batch_size)
+    if args.use_pretraining:
+        pretrain_splits = args.pretrain_splits
+        assert len(pretrain_splits) > 0, \
+            'must specify at least one pretrain split'
+        pretrain_env = make_more_train_env(
+            args, vocab, pretrain_splits, batch_size=batch_size)
+    else:
+        pretrain_env = None
 
     agent = Seq2SeqAgent(
         train_env, "", encoder, decoder, max_episode_len,
         max_instruction_length=MAX_INPUT_LENGTH)
-    return agent, train_env, val_envs
+    return agent, train_env, val_envs, pretrain_env


# Test set prediction will be handled separately
@@ -234,9 +252,21 @@ def train_setup(args, batch_size=BATCH_SIZE):

 def train_val(args):
     ''' Train on the training set, and validate on seen and unseen splits. '''
-    agent, train_env, val_envs = train_setup(args)
-    train(args, train_env, agent, val_envs=val_envs)
+    agent, train_env, val_envs, pretrain_env = train_setup(args)
+
+    encoder_optimizer = optim.Adam(
+        filter_param(agent.encoder.parameters()), lr=learning_rate,
+        weight_decay=weight_decay)
+    decoder_optimizer = optim.Adam(
+        filter_param(agent.decoder.parameters()), lr=learning_rate,
+        weight_decay=weight_decay)
+
+    if pretrain_env:
+        train(args, pretrain_env, agent, encoder_optimizer, decoder_optimizer,
+              args.n_pretrain_iters, val_envs=val_envs)
+
+    train(args, train_env, agent, encoder_optimizer, decoder_optimizer,
+          args.n_iters, val_envs=val_envs)
 
 
 # Test set prediction will be handled separately
 # def test_submission(args):
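
Taken together, train_val now drives a two-stage schedule over shared optimizers: train() on pretrain_env for args.n_pretrain_iters, then train() on the original train_env for args.n_iters. Under the split-name assumption above, an invocation might look like: python tasks/R2R/train.py --use_pretraining --pretrain_splits literal_speaker_data_augmentation_paths (feature and feedback flags omitted).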
@@ -264,6 +294,9 @@ def make_arg_parser():
default="sample")
parser.add_argument("--bidirectional", action='store_true')
parser.add_argument("--n_iters", type=int, default=20000)
parser.add_argument("--use_pretraining", action='store_true')
parser.add_argument("--pretrain_splits", nargs="+", default=[])
parser.add_argument("--n_pretrain_iters", type=int, default=50000)
parser.add_argument("--no_save", action='store_true')
parser.add_argument(
"--use_train_subset", action='store_true',
