Skip to content

Commit

Permalink
Propogate keep_checkpoint_every_n_hours to the default Saver inside e…
Browse files Browse the repository at this point in the history
…stimator.

PiperOrigin-RevId: 159138894
  • Loading branch information
ispirmustafa authored and tensorflower-gardener committed Jun 15, 2017
1 parent 1588d37 commit fa92763
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tensorflow/python/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,15 @@ def _train_model(self, input_fn, hooks):

if not (estimator_spec.scaffold.saver or
ops.get_collection(ops.GraphKeys.SAVERS)):
ops.add_to_collection(ops.GraphKeys.SAVERS,
training.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
defer_build=True,
save_relative_paths=True))
ops.add_to_collection(
ops.GraphKeys.SAVERS,
training.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
keep_checkpoint_every_n_hours=(
self._config.keep_checkpoint_every_n_hours),
defer_build=True,
save_relative_paths=True))

chief_hooks = []
if (self._config.save_checkpoints_secs or
Expand Down Expand Up @@ -862,7 +865,8 @@ def _write_dict_to_summary(output_dir,
value.simple_value = int(dictionary[key])
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.',
'Skipping summary for %s, must be a float, np.float32, np.int64, '
'np.int32 or int.',
key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()

0 comments on commit fa92763

Please sign in to comment.