Skip to content

Commit

Permalink
fix selfplay data augmentation: add sampled paths
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed Sep 17, 2018
1 parent 569b977 commit ff43811
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
1 change: 1 addition & 0 deletions tasks/R2R/data/download.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ wget https://storage.googleapis.com/bringmeaspoon/R2Rdata/R2R_train.json -P task
wget https://storage.googleapis.com/bringmeaspoon/R2Rdata/R2R_val_seen.json -P tasks/R2R/data/
wget https://storage.googleapis.com/bringmeaspoon/R2Rdata/R2R_val_unseen.json -P tasks/R2R/data/
wget https://storage.googleapis.com/bringmeaspoon/R2Rdata/R2R_test.json -P tasks/R2R/data/
wget http:https://people.eecs.berkeley.edu/~ronghang/projects/speaker_follower/data_augmentation/R2R_data_augmentation_paths.json -P tasks/R2R/data/
3 changes: 2 additions & 1 deletion tasks/R2R/selfplay_from_speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def make_arg_parser():
parser.add_argument("speaker_model_prefix")
parser.add_argument("pred_results_output_file")
parser.add_argument("--batch_size", type=int, default=20)
parser.add_argument("--pred_splits", nargs="+", default=["train_selfplay"])
parser.add_argument("--pred_splits", nargs="+",
default=["data_augmentation_paths"])

# for rational self-play generation
parser.add_argument("--follower_model_prefix",
Expand Down
3 changes: 0 additions & 3 deletions tasks/R2R/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def distance(pose1, pose2):
def load_datasets(splits):
data = []
for split in splits:
assert split in ['train', 'train_subset', 'val_seen', 'val_unseen', 'test',
'sub_train', 'sub_val_seen', 'sub_val_unseen', 'sub_train_subset',
'train_selfplay', 'train_selfplay_small']
with open('tasks/R2R/data/R2R_%s.json' % split) as f:
data += json.load(f)
return data
Expand Down

0 comments on commit ff43811

Please sign in to comment.