Skip to content

Commit

Permalink
forward compatible RNNs
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Apr 6, 2017
1 parent 1354e5c commit 449699d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions network.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ def _zero_pad(inputs):
def _rnn(acts, input_dim, cell_type, scope=None):
if cell_type == 'gru':
logger.info("Adding cell type " + cell_type + " to rnn")
cell = tf.nn.rnn_cell.GRUCell(input_dim)
cell = tf.contrib.rnn.GRUCell(input_dim)
elif cell_type == 'lstm':
logger.info("Adding cell type " + cell_type + " to rnn")
cell = tf.nn.rnn_cell.LSTMCell(input_dim)
cell = tf.contrib.LSTMCell(input_dim)
else:
msg = "Invalid cell type {}".format(cell_type)
raise ValueError(msg)
Expand All @@ -185,11 +185,11 @@ def _bi_rnn(acts, input_dim, cell_type):

# Backwards
with tf.variable_scope("bw") as bw_scope:
reverse_dims = [False, True, False]
acts_bw = tf.reverse(acts, dims=reverse_dims)
reverse_dims = [1]
acts_bw = tf.reverse(acts, axis=reverse_dims)
acts_bw = _rnn(acts_bw, input_dim, cell_type,
scope=bw_scope)
acts_bw = tf.reverse(acts_bw, dims=reverse_dims)
acts_bw = tf.reverse(acts_bw, axis=reverse_dims)

# Sum the forward and backward states.
return tf.add(acts_fw, acts_bw)
Expand Down

0 comments on commit 449699d

Please sign in to comment.