Commit e7ef320

Add 'minival' split for early stopping.

PiperOrigin-RevId: 322246032
mingxingtan authored and allenwang28 committed Jul 27, 2020
1 parent e4350c4 commit e7ef320
Showing 3 changed files with 68 additions and 20 deletions.
10 changes: 7 additions & 3 deletions models/official/efficientnet/README.md
@@ -49,7 +49,9 @@ EfficientNets achieve state-of-the-art accuracy on ImageNet with an order of mag

## 2. Using Pretrained EfficientNet Checkpoints

-We have provided a list of EfficientNet checkpoints for EfficientNet checkpoints:.
+To train EfficientNet on ImageNet, we hold out 25,022 randomly picked images ([image filenames](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/eval_data/val_split20.txt), or 20 of the 1024 total shards) as a 'minival' split and perform early stopping on it. The final accuracy is still reported on the original ImageNet validation set.
+
+We have provided a list of EfficientNet checkpoints:

* With baseline ResNet preprocessing, we achieve similar results to the original ICML paper.
* With [AutoAugment](https://arxiv.org/abs/1805.09501) preprocessing, we achieve higher accuracy than the original ICML paper.
@@ -59,8 +61,8 @@ We have provided a list of EfficientNet checkpoints for EfficientNet checkpoints

| | B0 | B1 | B2 | B3 | B4 | B5 | B6 | B7 | B8 | L2-475 | L2 |
|---------- |-------- | ------| ------|------ |------ |------ | --- | --- | --- | --- |--- |
-| Baseline preprocessing | 76.8% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b0.tar.gz)) | 78.8% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b1.tar.gz)) | 79.8% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b2.tar.gz)) | 81.0% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b3.tar.gz)) | 82.6% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b4.tar.gz)) | 83.2% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b5.tar.gz)) | | || | | |
-| AutoAugment (AA) | 77.3% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b0.tar.gz)) | 79.2% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b1.tar.gz)) | 80.3% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b2.tar.gz)) | 81.7% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b3.tar.gz)) | 83.0% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b4.tar.gz)) | 83.7% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b5.tar.gz)) | 84.2% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b6.tar.gz)) | 84.5% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b7.tar.gz)) || | |
+| Baseline preprocessing | 76.7% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b0.tar.gz)) | 78.7% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b1.tar.gz)) | 79.8% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b2.tar.gz)) | 81.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b3.tar.gz)) | 82.5% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b4.tar.gz)) | 83.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b5.tar.gz)) | | || | | |
+| AutoAugment (AA) | 77.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b0.tar.gz)) | 79.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b1.tar.gz)) | 80.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b2.tar.gz)) | 81.6% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b3.tar.gz)) | 82.9% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b4.tar.gz)) | 83.6% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b5.tar.gz)) | 84.0% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b6.tar.gz)) | 84.3% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b7.tar.gz)) || | |
| RandAugment (RA) | | | | | | 83.9% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/randaug/efficientnet-b5-randaug.tar.gz)) | | 85.0% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/randaug/efficientnet-b7-randaug.tar.gz)) | 85.4% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/randaug/efficientnet-b8-randaug.tar.gz)) | | |
| AdvProp + AA | 77.6% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b0.tar.gz)) | 79.6% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b1.tar.gz)) | 80.5% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b2.tar.gz)) | 81.9% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b3.tar.gz)) | 83.3% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b4.tar.gz)) | 84.3% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b5.tar.gz)) | 84.8% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b6.tar.gz)) | 85.2% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b7.tar.gz)) | 85.5% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b8.tar.gz))|| | |
| NoisyStudent + RA | 78.8% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b0.tar.gz)) | 81.5% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b1.tar.gz)) | 82.4% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b2.tar.gz)) | 84.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b3.tar.gz)) | 85.3% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b4.tar.gz)) | 86.1% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b5.tar.gz)) | 86.4% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b6.tar.gz)) | 86.9% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b7.tar.gz)) | - |88.2%([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-l2_475.tar.gz))|88.4% ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-l2.tar.gz)) |
@@ -77,6 +79,8 @@ We have provided a list of EfficientNet checkpoints for EfficientNet checkpoints

<sup>* NoisyStudent training code coming soon. L2-475 means the same L2 architecture with input image size 475 (please set "--input_image_size=475" when using this checkpoint). If you use NoisyStudent checkpoints, you can cite this [paper](https://arxiv.org/abs/1911.04252).</sup>

+<sup>* Note that the RandAugment, AdvProp, and NoisyStudent results are derived from baselines that do not use the holdout eval set; they will be updated in the future.</sup>

A quick way to use these checkpoints is to run:

$ export MODEL=efficientnet-b0
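
As a sanity check on the 25,022 figure above: holding out 20 of the 1,024 training shards removes a proportional slice of ImageNet's 1,281,167 training images. A minimal sketch of the arithmetic (it mirrors the `holdout_images` computation added to `main.py` below):

```python
# Reproduce the quoted minival size, assuming ImageNet's standard
# 1,281,167 training images spread evenly over 1,024 shards.
NUM_TRAIN_IMAGES = 1281167
TOTAL_SHARDS = 1024
HOLDOUT_SHARDS = 20

holdout_images = int(NUM_TRAIN_IMAGES * HOLDOUT_SHARDS / TOTAL_SHARDS)
print(holdout_images)                     # 25022 -> the 'minival' split
print(NUM_TRAIN_IMAGES - holdout_images)  # 1256145 images left for training
```
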
28 changes: 23 additions & 5 deletions models/official/efficientnet/imagenet_input.py
@@ -303,7 +303,8 @@ def __init__(self,
               mixup_alpha=0.0,
               randaug_num_layers=None,
               randaug_magnitude=None,
-              resize_method=None):
+              resize_method=None,
+              holdout_shards=None):
"""Create an input from TFRecord files.
Args:
@@ -331,6 +332,7 @@ def __init__(self,
randaug_magnitude: 'int', if RandAug is used, what should the magnitude
be. See autoaugment.py for detailed description.
      resize_method: If None, use bicubic by default.
+      holdout_shards: number of training shards to hold out for validation.
"""
super(ImageNetInput, self).__init__(
is_training=is_training,
@@ -348,6 +350,7 @@ def __init__(self,
self.data_dir = None
self.num_parallel_calls = num_parallel_calls
self.cache = cache
+    self.holdout_shards = holdout_shards

def _get_null_input(self, data):
"""Returns a null image (all black pixels).
@@ -375,14 +378,29 @@ def make_source_dataset(self, index, num_hosts):
logging.info('Undefined data_dir implies null input')
return tf.data.Dataset.range(1).repeat().map(self._get_null_input)

-    # Shuffle the filenames to ensure better randomization.
-    file_pattern = os.path.join(
-        self.data_dir, 'train-*' if self.is_training else 'validation-*')
+    if self.holdout_shards:
+      if self.is_training:
+        filenames = [
+            os.path.join(self.data_dir, 'train-%05d-of-01024' % i)
+            for i in range(self.holdout_shards, 1024)
+        ]
+      else:
+        filenames = [
+            os.path.join(self.data_dir, 'train-%05d-of-01024' % i)
+            for i in range(0, self.holdout_shards)
+        ]
+      for f in filenames[:10]:
+        logging.info('datafiles: %s', f)
+      dataset = tf.data.Dataset.from_tensor_slices(filenames)
+    else:
+      file_pattern = os.path.join(
+          self.data_dir, 'train-*' if self.is_training else 'validation-*')
+      logging.info('datafiles: %s', file_pattern)
+      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)

    # For multi-host training, we want each host to always process the same
    # subset of files. Each host only sees a subset of the entire dataset,
    # allowing us to cache larger datasets in memory.
-    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.shard(num_hosts, index)

if self.is_training and not self.cache:
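
The new branch in `make_source_dataset` selects training shards by index instead of by file pattern: the first `holdout_shards` files become the minival eval split, and the remainder stay in the training set. A standalone sketch of that selection logic (the data directory is illustrative; the real code feeds the filenames into a sharded `tf.data` pipeline):

```python
import os

def split_train_shards(data_dir, holdout_shards, is_training):
  """Mirrors the diff's selection: shards [0, holdout_shards) are 'minival',
  shards [holdout_shards, 1024) remain training data."""
  if is_training:
    shard_range = range(holdout_shards, 1024)
  else:
    shard_range = range(0, holdout_shards)
  return [
      os.path.join(data_dir, 'train-%05d-of-01024' % i) for i in shard_range
  ]

train_files = split_train_shards('/data/imagenet', 20, is_training=True)
minival_files = split_train_shards('/data/imagenet', 20, is_training=False)
assert len(train_files) + len(minival_files) == 1024
print(minival_files[0])  # /data/imagenet/train-00000-of-01024
print(train_files[0])    # /data/imagenet/train-00020-of-01024
```
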
50 changes: 38 additions & 12 deletions models/official/efficientnet/main.py
@@ -67,6 +67,16 @@
help=('The directory where the ImageNet input data is stored. Please see'
' the README.md for the expected data format.'))

+flags.DEFINE_integer(
+    'holdout_shards',
+    default=None,
+    help=('Number of holdout shards for validation. 0 means no holdout.'))
+
+flags.DEFINE_string('eval_name', default=None, help=('Evaluation name.'))
+
+flags.DEFINE_bool(
+    'archive_ckpt', default=True, help=('If true, archive the best ckpt.'))

flags.DEFINE_string(
'model_dir', default=None,
help=('The directory where the model and training/evaluation summaries are'
@@ -115,7 +125,7 @@
'train_batch_size', default=2048, help='Batch size for training.')

flags.DEFINE_integer(
-    'eval_batch_size', default=1024, help='Batch size for evaluation.')
+    'eval_batch_size', default=64, help='Batch size for evaluation.')

flags.DEFINE_integer(
'num_train_images', default=1281167, help='Size of training data set.')
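
Taken together, the three new flags let one job train against the reduced training set while its eval job scores the minival shards. A hypothetical invocation (paths are placeholders; `--mode` and `--model_name` are the repo's existing flags):

```sh
python main.py \
  --mode=train_and_eval \
  --model_name=efficientnet-b0 \
  --data_dir=gs://YOUR_BUCKET/imagenet \
  --model_dir=gs://YOUR_BUCKET/efficientnet-b0-minival \
  --holdout_shards=20 \
  --eval_name=minival \
  --archive_ckpt=true
```
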
@@ -218,9 +228,7 @@
default=0.016,
help=('Base learning rate when train batch size is 256.'))

-flags.DEFINE_float(
-    'momentum', default=0.9,
-    help=('Momentum parameter used in the MomentumOptimizer.'))
+flags.DEFINE_float('lr_decay_epoch', default=2.4, help='LR decay epoch.')

flags.DEFINE_float(
'moving_average_decay', default=0.9999,
@@ -386,8 +394,11 @@ def build_model():

scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
logging.info('base_learning_rate = %f', FLAGS.base_learning_rate)
-  learning_rate = utils.build_learning_rate(scaled_lr, global_step,
-                                            params['steps_per_epoch'])
+  learning_rate = utils.build_learning_rate(
+      scaled_lr,
+      global_step,
+      params['steps_per_epoch'],
+      decay_epochs=FLAGS.lr_decay_epoch)
optimizer = utils.build_optimizer(learning_rate)
if FLAGS.use_tpu:
# When using TPU, wrap the optimizer with CrossShardOptimizer which
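
The only behavioral change here is that the decay interval now comes from `--lr_decay_epoch` instead of being hard-coded. A rough sketch of what a stepped exponential schedule of this shape computes; the 0.97 decay factor and 5-epoch warmup match the EfficientNet paper's description, but treat the exact defaults of `utils.build_learning_rate` as assumptions:

```python
def exponential_lr(scaled_lr, global_step, steps_per_epoch,
                   decay_epochs=2.4, decay_factor=0.97, warmup_epochs=5):
  """Linear warmup, then multiply by decay_factor every decay_epochs."""
  warmup_steps = warmup_epochs * steps_per_epoch
  if global_step < warmup_steps:
    return scaled_lr * global_step / warmup_steps
  return scaled_lr * decay_factor ** int(
      global_step / (decay_epochs * steps_per_epoch))

# With train_batch_size=2048: scaled_lr = 0.016 * 2048 / 256 = 0.128.
steps_per_epoch = 1281167 // 2048  # 625 steps
for epoch in (1, 10, 100):
  step = epoch * steps_per_epoch
  print(epoch, round(exponential_lr(0.128, step, steps_per_epoch), 5))
```
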
Expand Down Expand Up @@ -605,6 +616,11 @@ def main(unused_argv):
input_image_size = model_builder_factory.get_model_input_size(
FLAGS.model_name)

+  if FLAGS.holdout_shards:
+    holdout_images = int(FLAGS.num_train_images * FLAGS.holdout_shards / 1024.0)
+    FLAGS.num_train_images -= holdout_images
+    FLAGS.num_eval_images = holdout_images

# For imagenet dataset, include background label if number of output classes
# is 1001
include_background_label = (FLAGS.num_label_classes == 1001)
@@ -692,7 +708,8 @@ def build_imagenet_input(is_training):
mixup_alpha=FLAGS.mixup_alpha,
randaug_num_layers=FLAGS.randaug_num_layers,
randaug_magnitude=FLAGS.randaug_magnitude,
-        resize_method=resize_method)
+        resize_method=resize_method,
+        holdout_shards=FLAGS.holdout_shards)

imagenet_train = build_imagenet_input(is_training=True)
imagenet_eval = build_imagenet_input(is_training=False)
@@ -708,14 +725,21 @@ def build_imagenet_input(is_training):
eval_results = est.evaluate(
input_fn=imagenet_eval.input_fn,
steps=eval_steps,
-          checkpoint_path=ckpt)
+          checkpoint_path=ckpt,
+          name=FLAGS.eval_name)
elapsed_time = int(time.time() - start_timestamp)
logging.info('Eval results: %s. Elapsed seconds: %d',
eval_results, elapsed_time)
-      utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)
+      if FLAGS.archive_ckpt:
+        utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)

# Terminate eval job when final checkpoint is reached
-      current_step = int(os.path.basename(ckpt).split('-')[1])
+      try:
+        current_step = int(os.path.basename(ckpt).split('-')[1])
+      except IndexError:
+        logging.info('%s has no global step info: stop!', ckpt)
+        break

if current_step >= FLAGS.train_steps:
logging.info(
'Evaluation finished after training step %d', current_step)
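
The new try/except guards against checkpoint paths whose basename carries no `-<step>` suffix. A quick illustration of the parsing it protects (paths are made up):

```python
import os

def step_from_ckpt(ckpt_path):
  # 'model.ckpt-109474' -> 109474; a bare name like 'checkpoint' raises.
  return int(os.path.basename(ckpt_path).split('-')[1])

print(step_from_ckpt('/tmp/model/model.ckpt-109474'))  # 109474
try:
  step_from_ckpt('/tmp/model/checkpoint')
except IndexError:
  print('no global step info: stop!')
```
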
@@ -777,11 +801,13 @@ def build_imagenet_input(is_training):
logging.info('Starting to evaluate.')
eval_results = est.evaluate(
input_fn=imagenet_eval.input_fn,
-          steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
+          steps=FLAGS.num_eval_images // FLAGS.eval_batch_size,
+          name=FLAGS.eval_name)
logging.info('Eval results at step %d: %s',
next_checkpoint, eval_results)
ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
-      utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)
+      if FLAGS.archive_ckpt:
+        utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)

elapsed_time = int(time.time() - start_timestamp)
logging.info('Finished training up to step %d. Elapsed seconds %d.',
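
Both eval paths now gate `utils.archive_ckpt` behind `--archive_ckpt`, so a minival eval job can keep only its best checkpoint while the final validation run leaves checkpoints untouched. A simplified sketch of keep-the-best logic of this kind; the real `utils.archive_ckpt` may differ in its details:

```python
import glob
import os
import shutil

def archive_if_best(objective, ckpt_path, archive_dir, best_so_far):
  """Copy the checkpoint's files aside when the objective improves.
  A stand-in for utils.archive_ckpt, not the actual implementation."""
  if objective <= best_so_far:
    return best_so_far
  os.makedirs(archive_dir, exist_ok=True)
  for f in glob.glob(ckpt_path + '.*'):  # index/meta/data shards
    shutil.copy(f, archive_dir)
  return objective

best = float('-inf')
# Inside the eval loop, e.g.:
# best = archive_if_best(eval_results['top_1_accuracy'], ckpt,
#                        os.path.join(model_dir, 'archive'), best)
```
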
