Test logic to extract and summarize attention weights.
mshseek committed Jan 23, 2019 · 1 parent 482e441 · commit 006f978
Showing 1 changed file with 34 additions and 0 deletions.
bert_attn_viz/tests/test_attention.py (+34 −0)
@@ -0,0 +1,34 @@
import pytest
import bert_attn_viz.explain.attention as attn
import tensorflow as tf

tf.enable_eager_execution()


@pytest.fixture
def sample_attention_weights():
    def create_dummy_weights():
        x = tf.random.normal((3, 8, 10, 10))
        x = tf.nn.softmax(x, axis=-1)

        return x

    return [
        {'layer_1': create_dummy_weights()},
        {'layer_2': create_dummy_weights()},
        {'layer_3': create_dummy_weights()}
    ]


def test_average_first_layer_by_head(sample_attention_weights):
    x = attn.average_first_layer_by_head(sample_attention_weights)
    y = attn.average_layer_i_on_token_j_by_head(0, 0, sample_attention_weights)

    assert tf.reduce_all(tf.equal(x, y))


def test_average_last_layer_by_head(sample_attention_weights):
    x = attn.average_last_layer_by_head(sample_attention_weights)
    y = attn.average_layer_i_on_token_j_by_head(-1, 0, sample_attention_weights)

    assert tf.reduce_all(tf.equal(x, y))
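
Each fixture entry maps a layer name to a (batch=3, num_heads=8, seq_len=10, seq_len=10) tensor of softmaxed scores, so the two tests pin down one contract: the first-/last-layer helpers must agree with attn.average_layer_i_on_token_j_by_head at layer 0 / layer -1 and token 0 (the [CLS] position in BERT inputs). A minimal sketch of implementations consistent with those assertions, assuming that input layout, might look like the following; the function names come from the test file, but the bodies are guesses, not the repository's actual code:

import tensorflow as tf

def average_layer_i_on_token_j_by_head(i, j, attention_weights):
    # attention_weights: list of single-entry dicts, one per layer, each
    # holding a (batch, num_heads, seq_len, seq_len) attention tensor.
    # (Assumed layout, inferred from the fixture above.)
    layer = list(attention_weights[i].values())[0]
    # Take the attention paid from token j, then average over the head
    # axis: (batch, num_heads, seq_len) -> (batch, seq_len).
    return tf.reduce_mean(layer[:, :, j, :], axis=1)

def average_first_layer_by_head(attention_weights):
    # Assumed to be the i=0, j=0 special case, per the first test's equality.
    return average_layer_i_on_token_j_by_head(0, 0, attention_weights)

def average_last_layer_by_head(attention_weights):
    # Assumed to be the i=-1, j=0 special case, per the second test's equality.
    return average_layer_i_on_token_j_by_head(-1, 0, attention_weights)

Under these assumptions the two asserts hold by construction; the tests would then catch any later refactor of the helpers that changes which layer, token, or axis gets averaged.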
