Skip to content

Commit

Permalink
avoid NaN in entropy computation
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed May 29, 2018
1 parent f7a6e69 commit 970c766
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion models_clevr/nmn3_netgen_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def loop_fn(time, cell_output, cell_state, loop_state):
token_prob = tf.reduce_sum(all_token_probs * tf.cast(mask, tf.float32), axis=1)
# tf.assert_positive(token_prob)
neg_entropy = tf.reduce_sum(
all_token_probs * tf.log(all_token_probs + (1-validity_mult)),
all_token_probs * tf.log(tf.maximum(1e-5, all_token_probs + (1-validity_mult))),
axis=1)

# update states
Expand Down
2 changes: 1 addition & 1 deletion models_shapes/nmn3_netgen_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def loop_fn(time, cell_output, cell_state, loop_state):
token_prob = tf.reduce_sum(all_token_probs *
tf.cast(mask, tf.float32), axis=1)
neg_entropy = tf.reduce_sum(all_token_probs *
tf.log(all_token_probs), axis=1)
tf.log(tf.maximum(1e-5, all_token_probs)), axis=1)

# is_eos_predicted is a [N] bool tensor, indicating whether
# <eos> has already been predicted previously in each sequence
Expand Down
2 changes: 1 addition & 1 deletion models_vqa/nmn3_netgen_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def loop_fn(time, cell_output, cell_state, loop_state):
token_prob = tf.reduce_sum(all_token_probs * tf.cast(mask, tf.float32), axis=1)
# tf.assert_positive(token_prob)
neg_entropy = tf.reduce_sum(
all_token_probs * tf.log(all_token_probs + (1-validity_mult)),
all_token_probs * tf.log(tf.maximum(1e-5, all_token_probs + (1-validity_mult))),
axis=1)

# update states
Expand Down

0 comments on commit 970c766

Please sign in to comment.