-
Notifications
You must be signed in to change notification settings - Fork 45
/
attention.py
268 lines (218 loc) · 11.8 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Copyright (c) 2022 The BayesFlow Developers
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Sequential
class MultiHeadAttentionBlock(tf.keras.Model):
    """Implements the MAB block from [1] which represents learnable cross-attention.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(self, input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs):
        """Builds a cross-attention block, typically used inside a set transformer [1].

        The forward pass is: attention + residual, optional layer norm,
        feedforward + residual, optional layer norm.

        Parameters
        ----------
        input_dim : int
            Size of the last axis of the query input. The feedforward network ends in a
            ``Dense(input_dim)`` layer so that its output can be added back residually.
        attention_settings : dict
            Keyword arguments unpacked into the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            Number of hidden ``Dense`` layers in the internal feedforward network
            (the final ``Dense(input_dim)`` output layer is added on top of these).
        dense_settings : dict
            Keyword arguments unpacked into each hidden ``Dense`` layer.
        use_layer_norm : boolean
            If True, apply ``LayerNormalization`` once after the attention residual sum
            and once after the feedforward residual sum; if False, skip both.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of tf.keras.Model
        """
        super().__init__(**kwargs)
        # Attribute names and creation order are kept fixed: Keras tracks sub-layers
        # by attribute name, so renaming them would change how weights are looked up.
        self.att = MultiHeadAttention(**attention_settings)
        # Despite the name, this norm is applied *after* the attention residual sum.
        self.ln_pre = LayerNormalization() if use_layer_norm else None
        hidden_layers = [Dense(**dense_settings) for _ in range(num_dense_fc)]
        self.fc = Sequential(hidden_layers)
        self.fc.add(Dense(input_dim))
        # Applied after the feedforward residual sum.
        self.ln_post = LayerNormalization() if use_layer_norm else None

    @staticmethod
    def _maybe_normalize(norm_layer, tensor, **kwargs):
        """Returns ``norm_layer(tensor)`` or ``tensor`` unchanged when the layer is ``None``."""
        return tensor if norm_layer is None else norm_layer(tensor, **kwargs)

    def call(self, x, y, **kwargs):
        """Runs cross-attention of ``x`` (queries) over ``y`` (keys and values).

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size_x, input_dim)
        y : tf.Tensor
            Input of shape (batch_size, set_size_y, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, set_size_x, input_dim)
        """
        # Attention sub-layer: queries from x, keys and values both from y.
        after_attention = x + self.att(x, y, y, **kwargs)
        after_attention = self._maybe_normalize(self.ln_pre, after_attention, **kwargs)
        # Feedforward sub-layer with its own residual connection.
        after_feedforward = after_attention + self.fc(after_attention, **kwargs)
        return self._maybe_normalize(self.ln_post, after_feedforward, **kwargs)
class SelfAttentionBlock(tf.keras.Model):
    """Implements the SAB block from [1] which represents learnable self-attention.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(self, input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs):
        """Builds a self-attention block, typically used inside a set transformer [1].

        Internally this is a single ``MultiHeadAttentionBlock`` whose query set and
        key/value set are the same input.

        Parameters
        ----------
        input_dim : int
            The dimensionality of the input data (last axis).
        attention_settings : dict
            Keyword arguments unpacked into the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            Number of hidden ``Dense`` layers in the internal feedforward network.
        dense_settings : dict
            Keyword arguments unpacked into each hidden ``Dense`` layer.
        use_layer_norm : boolean
            If True, the inner block applies layer normalization after its attention
            residual sum and after its feedforward residual sum.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of tf.keras.Model
        """
        super().__init__(**kwargs)
        self.mab = MultiHeadAttentionBlock(
            input_dim,
            attention_settings,
            num_dense_fc,
            dense_settings,
            use_layer_norm,
        )

    def call(self, x, **kwargs):
        """Lets every element of the set ``x`` attend to every element of ``x``.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, set_size, input_dim)
        """
        # Self-attention: the same set plays the role of queries and of keys/values.
        queries = keys_values = x
        return self.mab(queries, keys_values, **kwargs)
class InducedSelfAttentionBlock(tf.keras.Model):
    """Implements the ISAB block from [1] which represents learnable self-attention specifically
    designed to deal with large sets via a learnable set of "inducing points".

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(
        self, input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_inducing_points, **kwargs
    ):
        """Builds an induced self-attention block (ISAB), typically used inside a set transformer [1].

        A trainable matrix of inducing points first attends to the input set, and the
        input set then attends to that result, instead of attending to itself directly.

        Parameters
        ----------
        input_dim : int
            The dimensionality of the input data (last axis); also the width of each inducing point.
        attention_settings : dict
            Keyword arguments unpacked into the ``MultiHeadAttention`` layers.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            Number of hidden ``Dense`` layers in each internal feedforward network.
        dense_settings : dict
            Keyword arguments unpacked into each hidden ``Dense`` layer.
        use_layer_norm : boolean
            If True, both inner blocks apply layer normalization after their attention
            residual sum and after their feedforward residual sum.
        num_inducing_points : int
            The number of inducing points. Should be lower than the smallest set size.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of tf.keras.Model
        """
        super().__init__(**kwargs)
        glorot = tf.keras.initializers.GlorotUniform()
        # Trainable inducing points, shape (num_inducing_points, input_dim), shared across the batch.
        self.I = tf.Variable(glorot(shape=(num_inducing_points, input_dim)), name="I", trainable=True)
        mab_args = (input_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm)
        self.mab0 = MultiHeadAttentionBlock(*mab_args)
        self.mab1 = MultiHeadAttentionBlock(*mab_args)

    def call(self, x, **kwargs):
        """Applies induced self-attention to the set ``x``.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, set_size, input_dim)
        """
        num_sets = tf.shape(x)[0]
        # Replicate the inducing points for every set in the batch: (batch_size, num_inducing_points, input_dim).
        inducing = tf.tile(tf.expand_dims(self.I, axis=0), [num_sets, 1, 1])
        # Step 1: inducing points attend to the input set.
        induced_summary = self.mab0(inducing, x, **kwargs)
        # Step 2: the input set attends to the induced summary, restoring set_size rows.
        return self.mab1(x, induced_summary, **kwargs)
class PoolingWithAttention(tf.keras.Model):
    """Implements the pooling with multihead attention (PMA) block from [1] which represents
    a permutation-invariant encoder for set-based inputs.

    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
        Set transformer: A framework for attention-based permutation-invariant neural networks.
        In International conference on machine learning (pp. 3744-3753). PMLR.
    """

    def __init__(
        self, summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, num_seeds=1, **kwargs
    ):
        """Creates a multihead attention block (MAB) which will perform cross-attention between an input set
        and a set of seed vectors (typically one for a single summary) with summary_dim output dimensions.

        Could also be used as part of a ``DeepSet`` for representing learnable instead of fixed pooling.

        Parameters
        ----------
        summary_dim : int
            The dimensionality of the learned permutation-invariant representation.
        attention_settings : dict
            Keyword arguments unpacked into the ``MultiHeadAttention`` layer.
            See https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention.
        num_dense_fc : int
            Number of hidden ``Dense`` layers in each internal feedforward network.
        dense_settings : dict
            Keyword arguments unpacked into each hidden ``Dense`` layer.
        use_layer_norm : boolean
            If True, the inner attention block applies layer normalization after its
            attention residual sum and after its feedforward residual sum.
        num_seeds : int, optional, default: 1
            The number of "seed vectors" to use. Each seed vector represents a permutation-invariant
            summary of the entire set. If you use ``num_seeds > 1``, the resulting seeds will be flattened
            into a 2-dimensional output, which will have a dimensionality of ``num_seeds * summary_dim``
        **kwargs : dict, optional, default: {}
            Optional keyword arguments passed to the __init__() method of tf.keras.Model.
            They configure this wrapper only and are not forwarded to the inner attention block.
        """
        super().__init__(**kwargs)
        # Build the inner MAB without the outer kwargs, as SelfAttentionBlock and
        # InducedSelfAttentionBlock do; forwarding them would e.g. give the nested
        # model the same ``name`` as this wrapper.
        self.mab = MultiHeadAttentionBlock(summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm)
        init = tf.keras.initializers.GlorotUniform()
        # Trainable seed vectors, shape (num_seeds, summary_dim), shared across the batch.
        self.seed_vec = tf.Variable(init(shape=(num_seeds, summary_dim)), name="seed_vec", trainable=True)
        # Row-wise feedforward network mapping each set element to summary_dim before pooling.
        self.fc = Sequential([Dense(**dense_settings) for _ in range(num_dense_fc)])
        self.fc.add(Dense(summary_dim))

    def call(self, x, **kwargs):
        """Performs the forward pass through the PMA block.

        Parameters
        ----------
        x : tf.Tensor
            Input of shape (batch_size, set_size, input_dim)

        Returns
        -------
        out : tf.Tensor
            Output of shape (batch_size, num_seeds * summary_dim)
        """
        # Forward call kwargs (e.g. ``training``) to the feedforward net, consistent with
        # how MultiHeadAttentionBlock treats its own ``fc``.
        out = self.fc(x, **kwargs)
        batch_size = tf.shape(x)[0]
        # Replicate the seeds for every set in the batch: (batch_size, num_seeds, summary_dim).
        seed_tiled = tf.tile(self.seed_vec[None, ...], [batch_size, 1, 1])
        # Seeds act as queries over the projected set elements.
        out = self.mab(seed_tiled, out, **kwargs)
        # Flatten (batch_size, num_seeds, summary_dim) -> (batch_size, num_seeds * summary_dim).
        return tf.reshape(out, (tf.shape(out)[0], -1))