From a21420190b9945f74a8e8c040d999eb21d933212 Mon Sep 17 00:00:00 2001 From: awni Date: Fri, 7 Apr 2017 17:05:33 -0700 Subject: [PATCH] BUG: typo with LSTM cell --- network.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/network.py b/network.py index 0222cff..705de08 100644 --- a/network.py +++ b/network.py @@ -32,7 +32,7 @@ def init_inference(self, config): stride = layer['stride'] bn = layer.get('enable_batch_norm', None) ln = layer.get('enable_layer_norm', None) - + if bn is not None or ln is not None: acts = tf.contrib.layers.convolution2d(acts, num_outputs=num_filters, kernel_size=[filter_size, 1], @@ -43,16 +43,16 @@ def init_inference(self, config): if bn == True: logger.debug("Adding Batch Norm Layer") - acts = tf.contrib.layers.batch_norm(acts, decay=0.9, center=True, - scale=True, epsilon=1e-8, - activation_fn=tf.nn.relu, + acts = tf.contrib.layers.batch_norm(acts, decay=0.9, center=True, + scale=True, epsilon=1e-8, + activation_fn=tf.nn.relu, is_training=True) - + elif ln == True: logger.debug("Adding Layer Norm Layer") - acts = tf.contrib.layers.layer_norm(acts, center=True, - scale=True, - activation_fn=tf.nn.relu) + acts = tf.contrib.layers.layer_norm(acts, center=True, + scale=True, + activation_fn=tf.nn.relu) else: assert True, "Batch or Layer norm must be specified as True" else: @@ -190,7 +190,7 @@ def _rnn(acts, input_dim, cell_type, scope=None): cell = tf.contrib.rnn.GRUCell(input_dim) elif cell_type == 'lstm': logger.info("Adding cell type " + cell_type + " to rnn") - cell = tf.contrib.LSTMCell(input_dim) + cell = tf.contrib.rnn.LSTMCell(input_dim) else: msg = "Invalid cell type {}".format(cell_type) raise ValueError(msg)