Skip to content

Commit

Permalink
Enable name based definition of keras initializers in hyperparams.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 365913722
  • Loading branch information
Austin Myers authored and TF Object Detection Team committed Mar 30, 2021
1 parent f55a0eb commit 718b307
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
9 changes: 8 additions & 1 deletion research/object_detection/builders/hyperparams_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _build_initializer(initializer, build_for_keras=False):
operators. If false builds for Slim.
Returns:
tf initializer.
tf initializer or string corresponding to the tf keras initializer name.
Raises:
ValueError: On unknown initializer.
Expand Down Expand Up @@ -415,6 +415,13 @@ def _build_initializer(initializer, build_for_keras=False):
factor=initializer.variance_scaling_initializer.factor,
mode=mode,
uniform=initializer.variance_scaling_initializer.uniform)
if initializer_oneof == 'keras_initializer_by_name':
if build_for_keras:
return initializer.keras_initializer_by_name
else:
raise ValueError(
'Unsupported non-Keras usage of keras_initializer_by_name: {}'.format(
initializer.keras_initializer_by_name))
if initializer_oneof is None:
return None
raise ValueError('Unknown initializer function: {}'.format(
Expand Down
21 changes: 21 additions & 0 deletions research/object_detection/builders/hyperparams_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,5 +1030,26 @@ def test_variance_in_range_with_random_normal_initializer_keras(self):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1)

def test_keras_initializer_by_name(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
keras_initializer_by_name: "glorot_uniform"
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Parse(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
initializer_arg = keras_config.params()['kernel_initializer']
conv_layer = tf.keras.layers.Conv2D(
filters=16, kernel_size=3, **keras_config.params())
self.assertEqual(initializer_arg, 'glorot_uniform')
self.assertIsInstance(conv_layer.kernel_initializer,
type(tf.keras.initializers.get('glorot_uniform')))

if __name__ == '__main__':
tf.test.main()
5 changes: 5 additions & 0 deletions research/object_detection/protos/hyperparams.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ message Initializer {
TruncatedNormalInitializer truncated_normal_initializer = 1;
VarianceScalingInitializer variance_scaling_initializer = 2;
RandomNormalInitializer random_normal_initializer = 3;
// Allows specifying initializers by name, as a string, which will be passed
// directly as an argument during layer construction. Currently, this is
// only supported when using KerasLayerHyperparams, and for valid Keras
// initializers, e.g. `glorot_uniform`, `variance_scaling`, etc.
string keras_initializer_by_name = 4;
}
}

Expand Down

0 comments on commit 718b307

Please sign in to comment.