Skip to content

Commit

Permalink
Reuse code create_hp_and_estimator().
Browse files Browse the repository at this point in the history
  • Loading branch information
thtrieu committed Aug 16, 2019
1 parent 4ceb2cb commit 15bb117
Showing 1 changed file with 5 additions and 25 deletions.
30 changes: 5 additions & 25 deletions decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
FLAGS = flags.FLAGS


def create_hp_and_estimator(problem_name, data_dir, checkpoint_path):
def create_hp_and_estimator(
problem_name, data_dir, checkpoint_path, decode_to_file=None):
trainer_lib.set_random_seed(FLAGS.random_seed)

hp = trainer_lib.create_hparams(
Expand All @@ -43,7 +44,7 @@ def create_hp_and_estimator(problem_name, data_dir, checkpoint_path):
decode_hp.shard_id = FLAGS.worker_id
decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
decode_hp.decode_in_memory = decode_in_memory
decode_hp.decode_to_file = None
decode_hp.decode_to_file = decode_to_file
decode_hp.decode_reference = None

FLAGS.checkpoint_path = checkpoint_path
Expand Down Expand Up @@ -364,29 +365,8 @@ def timer(gen):
def t2t_decoder(problem_name, data_dir,
decode_from_file, decode_to_file,
checkpoint_path):
trainer_lib.set_random_seed(FLAGS.random_seed)

hp = trainer_lib.create_hparams(
FLAGS.hparams_set,
FLAGS.hparams,
data_dir=os.path.expanduser(data_dir),
problem_name=problem_name)

decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
decode_hp.shards = FLAGS.decode_shards
decode_hp.shard_id = FLAGS.worker_id
decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
decode_hp.decode_in_memory = decode_in_memory
decode_hp.decode_to_file = decode_to_file
decode_hp.decode_reference = None

FLAGS.checkpoint_path = checkpoint_path
estimator = trainer_lib.create_estimator(
FLAGS.model,
hp,
t2t_trainer.create_run_config(hp),
decode_hparams=decode_hp,
use_tpu=FLAGS.use_tpu)
hp, decode_hp, estimator = create_hp_and_estimator(
problem_name, data_dir, checkpoint_path, decode_to_file)

decode_from_text_file(
estimator, problem_name,
Expand Down

0 comments on commit 15bb117

Please sign in to comment.