Skip to content

Commit

Permalink
BUG: typo with LSTM cell
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Apr 8, 2017
1 parent 9b7028a commit a214201
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions network.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def init_inference(self, config):
stride = layer['stride']
bn = layer.get('enable_batch_norm', None)
ln = layer.get('enable_layer_norm', None)

if bn is not None or ln is not None:
acts = tf.contrib.layers.convolution2d(acts, num_outputs=num_filters,
kernel_size=[filter_size, 1],
Expand All @@ -43,16 +43,16 @@ def init_inference(self, config):

if bn == True:
logger.debug("Adding Batch Norm Layer")
acts = tf.contrib.layers.batch_norm(acts, decay=0.9, center=True,
scale=True, epsilon=1e-8,
activation_fn=tf.nn.relu,
acts = tf.contrib.layers.batch_norm(acts, decay=0.9, center=True,
scale=True, epsilon=1e-8,
activation_fn=tf.nn.relu,
is_training=True)

elif ln == True:
logger.debug("Adding Layer Norm Layer")
acts = tf.contrib.layers.layer_norm(acts, center=True,
scale=True,
activation_fn=tf.nn.relu)
acts = tf.contrib.layers.layer_norm(acts, center=True,
scale=True,
activation_fn=tf.nn.relu)
else:
assert True, "Batch or Layer norm must be specified as True"
else:
Expand Down Expand Up @@ -190,7 +190,7 @@ def _rnn(acts, input_dim, cell_type, scope=None):
cell = tf.contrib.rnn.GRUCell(input_dim)
elif cell_type == 'lstm':
logger.info("Adding cell type " + cell_type + " to rnn")
cell = tf.contrib.LSTMCell(input_dim)
cell = tf.contrib.rnn.LSTMCell(input_dim)
else:
msg = "Invalid cell type {}".format(cell_type)
raise ValueError(msg)
Expand Down

0 comments on commit a214201

Please sign in to comment.