Commit 716687c

Generalize learning.train for multiple sync optimizers.
PiperOrigin-RevId: 159595373
tensorflower-gardener committed Jun 20, 2017
1 parent 74cf446 commit 716687c
Showing 1 changed file with 22 additions and 16 deletions.
tensorflow/contrib/slim/python/slim/learning.py
@@ -603,9 +603,9 @@ def train(train_op,
     saver: Saver to save checkpoints. If None, a default one will be created
       and used.
     save_interval_secs: How often, in seconds, to save the model to `logdir`.
-    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer. If the
-      argument is supplied, gradient updates will be synchronous. If left as
-      `None`, gradient updates will be asynchronous.
+    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
+      them. If the argument is supplied, gradient updates will be synchronous.
+      If left as `None`, gradient updates will be asynchronous.
     session_config: An instance of `tf.ConfigProto` that will be used to
       configure the `Session`. If left as `None`, the default will be used.
     trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
@@ -633,6 +633,8 @@ def train(train_op,
       raise ValueError('Cannot provide trace_every_n_steps because '
                        'logdir=None')
 
+  if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
+    sync_optimizer = [sync_optimizer]
   if sync_optimizer is not None and startup_delay_steps > 0:
     raise ValueError(
         'startup_delay_steps must be zero when sync_optimizer is supplied.')
@@ -647,6 +649,12 @@ def train(train_op,
       global_step = variables.get_or_create_global_step()
     saver = saver or tf_saver.Saver()
 
+    if sync_optimizer is not None:
+      for opt in sync_optimizer:
+        if not isinstance(opt, sync_replicas_optimizer.SyncReplicasOptimizer):
+          raise ValueError(
+              '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')
+
     with ops.name_scope('init_ops'):
       if init_op == _USE_DEFAULT:
         init_op = tf_variables.global_variables_initializer()
@@ -659,15 +667,17 @@ def train(train_op,
             tf_variables.local_variables_initializer(),
             lookup_ops.tables_initializer())
 
-      if sync_optimizer is not None and isinstance(
-          sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
+      if sync_optimizer is not None and isinstance(sync_optimizer, list):
         with ops.control_dependencies([local_init_op] if local_init_op is
                                       not None else []):
           if is_chief:
-            local_init_op = sync_optimizer.chief_init_op
+            local_init_op = control_flow_ops.group(
+                *[opt.chief_init_op for opt in sync_optimizer])
           else:
-            local_init_op = sync_optimizer.local_step_init_op
-        ready_for_local_init_op = sync_optimizer.ready_for_local_init_op
+            local_init_op = control_flow_ops.group(
+                *[opt.local_step_init_op for opt in sync_optimizer])
+        ready_for_local_init_op = control_flow_ops.group(
+            *[opt.ready_for_local_init_op for opt in sync_optimizer])
       else:
         ready_for_local_init_op = None
 
@@ -678,14 +688,10 @@ def train(train_op,
       summary_writer = supervisor.Supervisor.USE_DEFAULT
 
     if is_chief and sync_optimizer is not None:
-      if not isinstance(sync_optimizer,
-                        (sync_replicas_optimizer.SyncReplicasOptimizer)):
-        raise ValueError(
-            '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')
-
       # Need to create these BEFORE the supervisor finalizes the graph:
-      init_tokens_op = sync_optimizer.get_init_tokens_op()
-      chief_queue_runner = sync_optimizer.get_chief_queue_runner()
+      init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer]
+      chief_queue_runner = [
+          opt.get_chief_queue_runner() for opt in sync_optimizer]
 
   if train_step_kwargs == _USE_DEFAULT:
     with ops.name_scope('train_step'):
@@ -741,7 +747,7 @@ def train(train_op,
       threads = sv.start_queue_runners(sess)
       logging.info('Starting Queues.')
       if is_chief and sync_optimizer is not None:
-        sv.start_queue_runners(sess, [chief_queue_runner])
+        sv.start_queue_runners(sess, chief_queue_runner)
         sess.run(init_tokens_op)
       try:
         while not sv.should_stop():
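
After this change, `sync_optimizer` accepts either a single tf.train.SyncReplicasOptimizer (internally wrapped into a one-element list) or a list of them; the chief init ops, local-step init ops, init-token ops, and chief queue runners are then collected across every optimizer in the list. Below is a minimal editorial sketch of the new call pattern, not code from this commit; it assumes a TF 1.x cluster where `server` and `task_index` come from an existing tf.train.Server setup, and where `loss_a`, `loss_b`, `vars_a`, and `vars_b` are two loss/variable partitions built elsewhere.

    # Sketch only. `server`, `task_index`, `loss_a`, `loss_b`, `vars_a`,
    # and `vars_b` are placeholder assumptions, not part of this commit.
    import tensorflow as tf

    slim = tf.contrib.slim

    # One synchronous optimizer per sub-network.
    opt_a = tf.train.SyncReplicasOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=0.1),
        replicas_to_aggregate=4)
    opt_b = tf.train.SyncReplicasOptimizer(
        tf.train.AdagradOptimizer(learning_rate=0.01),
        replicas_to_aggregate=4)

    train_op_a = slim.learning.create_train_op(
        loss_a, opt_a, variables_to_train=vars_a)
    # global_step=None so only one of the two ops increments the step.
    train_op_b = slim.learning.create_train_op(
        loss_b, opt_b, global_step=None, variables_to_train=vars_b)
    # Summing the per-optimizer losses yields a Tensor whose evaluation
    # triggers both sets of gradient updates via control dependencies.
    train_op = train_op_a + train_op_b

    slim.learning.train(
        train_op,
        logdir='/tmp/train_logs',
        master=server.target,
        is_chief=(task_index == 0),
        startup_delay_steps=0,          # must be 0 when sync_optimizer is set
        sync_optimizer=[opt_a, opt_b])  # a list is now accepted; a single
                                        # SyncReplicasOptimizer still works

Passing a single optimizer still behaves exactly as before, since the first new block in train() normalizes it to a one-element list.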
