Skip to content

Commit

Permalink
Returning eval_on_train_input_fn from create_estimator_and_inputs(), …
Browse files Browse the repository at this point in the history
…rather than using train_input_fn in EVAL mode (which will still have data augmentation).

PiperOrigin-RevId: 192320460
  • Loading branch information
pkulzc committed Apr 13, 2018
1 parent 7e81000 commit 227f41e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
12 changes: 11 additions & 1 deletion research/object_detection/model_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def create_estimator_and_inputs(run_config,
'estimator': An `Estimator` or `TPUEstimator`.
'train_input_fn': A training input function.
'eval_input_fn': An evaluation input function.
'eval_on_train_input_fn': An evaluation-on-train input function.
'predict_input_fn': A prediction input function.
'train_steps': Number of training steps. Either directly from input or from
configuration.
Expand Down Expand Up @@ -484,6 +485,10 @@ def create_estimator_and_inputs(run_config,
eval_config=eval_config,
eval_input_config=eval_input_config,
model_config=model_config)
eval_on_train_input_fn = create_eval_input_fn(
eval_config=eval_config,
eval_input_config=train_input_config,
model_config=model_config)
predict_input_fn = create_predict_input_fn(model_config=model_config)

model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
Expand All @@ -509,13 +514,15 @@ def create_estimator_and_inputs(run_config,
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_on_train_input_fn=eval_on_train_input_fn,
predict_input_fn=predict_input_fn,
train_steps=train_steps,
eval_steps=eval_steps)


def create_train_and_eval_specs(train_input_fn,
eval_input_fn,
eval_on_train_input_fn,
predict_input_fn,
train_steps,
eval_steps,
Expand All @@ -527,6 +534,8 @@ def create_train_and_eval_specs(train_input_fn,
Args:
train_input_fn: Function that produces features and labels on train data.
eval_input_fn: Function that produces features and labels on eval data.
eval_on_train_input_fn: Function that produces features and labels for
evaluation on train data.
predict_input_fn: Function that produces features for inference.
train_steps: Number of training steps.
eval_steps: Number of eval steps.
Expand Down Expand Up @@ -558,7 +567,8 @@ def create_train_and_eval_specs(train_input_fn,
if eval_on_train_data:
eval_specs.append(
tf.estimator.EvalSpec(
name='eval_on_train', input_fn=train_input_fn, steps=eval_steps))
name='eval_on_train', input_fn=eval_on_train_input_fn,
steps=eval_steps))

return train_spec, eval_specs

Expand Down
32 changes: 25 additions & 7 deletions research/object_detection/model_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,26 @@ def _assert_model_fn_for_train_eval(self, configs, mode,
model_config = configs['model']
train_config = configs['train_config']
with tf.Graph().as_default():
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == 'train':
features, labels = inputs.create_train_input_fn(
configs['train_config'],
configs['train_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.TRAIN
batch_size = train_config.batch_size
else:
elif mode == 'eval':
features, labels = inputs.create_eval_input_fn(
configs['eval_config'],
configs['eval_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.EVAL
batch_size = 1
elif mode == 'eval_on_train':
features, labels = inputs.create_eval_input_fn(
configs['eval_config'],
configs['train_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.EVAL
batch_size = 1

detection_model_fn = functools.partial(
Expand All @@ -103,7 +112,7 @@ def _assert_model_fn_for_train_eval(self, configs, mode,
hparams_overrides='load_pretrained=false')

model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
estimator_spec = model_fn(features, labels, mode)
estimator_spec = model_fn(features, labels, model_mode)

self.assertIsNotNone(estimator_spec.loss)
self.assertIsNotNone(estimator_spec.predictions)
Expand All @@ -121,7 +130,7 @@ def _assert_model_fn_for_train_eval(self, configs, mode,
self.assertEqual(batch_size, detection_scores.shape.as_list()[0])
self.assertEqual(tf.float32, detection_scores.dtype)
self.assertEqual(tf.float32, num_detections.dtype)
if mode == tf.estimator.ModeKeys.TRAIN:
if model_mode == tf.estimator.ModeKeys.TRAIN:
self.assertIsNotNone(estimator_spec.train_op)
return estimator_spec

Expand Down Expand Up @@ -152,12 +161,17 @@ def _assert_model_fn_for_predict(self, configs):
def test_model_fn_in_train_mode(self):
"""Tests the model function in TRAIN mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, tf.estimator.ModeKeys.TRAIN)
self._assert_model_fn_for_train_eval(configs, 'train')

def test_model_fn_in_eval_mode(self):
"""Tests the model function in EVAL mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, tf.estimator.ModeKeys.EVAL)
self._assert_model_fn_for_train_eval(configs, 'eval')

def test_model_fn_in_eval_on_train_mode(self):
"""Tests the model function in EVAL mode with train data."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, 'eval_on_train')

def test_model_fn_in_predict_mode(self):
"""Tests the model function in PREDICT mode."""
Expand All @@ -181,10 +195,12 @@ def test_create_estimator_and_inputs(self):
estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']

self.assertIsInstance(estimator, tf.estimator.Estimator)
self.assertEqual(20, train_steps)
self.assertEqual(10, eval_steps)
self.assertIn('train_input_fn', train_and_eval_dict)
self.assertIn('eval_input_fn', train_and_eval_dict)
self.assertIn('eval_on_train_input_fn', train_and_eval_dict)

def test_create_estimator_with_default_train_eval_steps(self):
"""Tests that number of train/eval defaults to config values."""
Expand Down Expand Up @@ -245,13 +261,15 @@ def test_create_train_and_eval_specs(self):
eval_steps=eval_steps)
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']

train_spec, eval_specs = model_lib.create_train_and_eval_specs(
train_input_fn,
eval_input_fn,
eval_on_train_input_fn,
predict_input_fn,
train_steps,
eval_steps,
Expand Down
2 changes: 2 additions & 0 deletions research/object_detection/model_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def main(unused_argv):
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']

train_spec, eval_specs = model_lib.create_train_and_eval_specs(
train_input_fn,
eval_input_fn,
eval_on_train_input_fn,
predict_input_fn,
train_steps,
eval_steps,
Expand Down
3 changes: 2 additions & 1 deletion research/object_detection/model_tpu_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def main(unused_argv):
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']

Expand Down Expand Up @@ -158,7 +159,7 @@ def terminate_eval():
tf.logging.info('Starting to evaluate.')
if FLAGS.eval_training_data:
name = 'training_data'
input_fn = train_input_fn
input_fn = eval_on_train_input_fn
else:
name = 'validation_data'
input_fn = eval_input_fn
Expand Down

0 comments on commit 227f41e

Please sign in to comment.