"""
Module Implementing the 'abstracter' (RelationalAbstracter).
We also implement an "ablation module" and an early experimental variant called SymbolicAbstracter.
The abstracter is a module for transformer-based models which aims to encourage
learning abstract relations.
It is characterized by employing learned input-independent 'symbols' in its computation
and using an attention mechanism which enforced the representation of purely relational
information.
"""
import tensorflow as tf
from transformer_modules import AddPositionalEmbedding, FeedForward
from attention import GlobalSelfAttention, BaseAttention, RelationalAttention, SymbolicAttention, CrossAttention


class RelationalAbstracter(tf.keras.layers.Layer):
    """
    An implementation of the 'Abstractor' module.

    This implementation uses TensorFlow's MultiHeadAttention layer
    to implement relational cross-attention.
    """

    def __init__(
            self,
            num_layers,
            num_heads,
            dff,
            use_pos_embedding=True,
            use_learned_symbols=True,
            mha_activation_type='softmax',
            use_self_attn=True,
            dropout_rate=0.1,
            name=None):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_heads : int
            number of 'heads' in relational cross-attention (relation dimension)
        dff : int
            dimension of intermediate layer in feedforward network
        use_pos_embedding : bool, optional
            whether to add positional embeddings to symbols, by default True
        use_learned_symbols : bool, optional
            whether to use learned symbols or nonparametric positional embeddings, by default True
        mha_activation_type : str, optional
            activation of MHA in relational cross-attention, by default 'softmax'
        use_self_attn : bool, optional
            whether to apply self-attention in addition to relational cross-attention, by default True
        dropout_rate : float, optional
            dropout rate, by default 0.1
        name : str, optional
            name of layer, by default None
        """
        super(RelationalAbstracter, self).__init__(name=name)

        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff = dff
        self.mha_activation_type = mha_activation_type
        self.use_pos_embedding = use_pos_embedding
        self.use_learned_symbols = use_learned_symbols
        # nonparametric symbols are simply (non-learned) positional embeddings
        if not self.use_learned_symbols:
            self.use_pos_embedding = True
        self.use_self_attn = use_self_attn
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        _, self.sequence_length, self.d_model = input_shape

        # define the input-independent symbol sequence used by the abstracter
        if self.use_learned_symbols:
            normal_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
            self.symbol_sequence = tf.Variable(
                normal_initializer(shape=(self.sequence_length, self.d_model)),
                name='symbols', trainable=True)

        # layer which adds positional embedding (to be used on symbol sequence)
        if self.use_pos_embedding:
            self.add_pos_embedding = AddPositionalEmbedding()

        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

        self.abstracter_layers = [
            RelationalAbstracterLayer(
                d_model=self.d_model, num_heads=self.num_heads, dff=self.dff,
                mha_activation_type=self.mha_activation_type,
                use_self_attn=self.use_self_attn,
                dropout_rate=self.dropout_rate)
            for _ in range(self.num_layers)]

        self.last_attn_scores = None

    def call(self, inputs):
        # the symbol sequence is input-independent, so the same one is used for all
        # inputs in the batch (addition broadcasts it across the batch dimension)
        symbol_seq = tf.zeros_like(inputs)
        if self.use_learned_symbols:
            symbol_seq = symbol_seq + self.symbol_sequence

        # add positional embedding
        if self.use_pos_embedding:
            symbol_seq = self.add_pos_embedding(symbol_seq)

        symbol_seq = self.dropout(symbol_seq)

        for i in range(self.num_layers):
            symbol_seq = self.abstracter_layers[i](symbol_seq, inputs)

        # cache the final layer's relational cross-attention scores for plotting later
        self.last_attn_scores = self.abstracter_layers[-1].last_attn_scores

        return symbol_seq


class RelationalAbstracterLayer(tf.keras.layers.Layer):
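    """
    A single layer of the RelationalAbstracter.

    Applies (optional) self-attention over the symbol sequence, followed by
    relational cross-attention between the symbols and the input objects, and
    an optional position-wise feedforward network (when `dff` is not None).
    """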

    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 use_self_attn=True,
                 mha_activation_type='softmax',
                 dropout_rate=0.1):
        super(RelationalAbstracterLayer, self).__init__()

        self.use_self_attn = use_self_attn

        if self.use_self_attn:
            self.self_attention = GlobalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                activation_type=mha_activation_type,
                dropout=dropout_rate)

        self.relational_crossattention = RelationalAttention(
            num_heads=num_heads,
            key_dim=d_model,
            activation_type=mha_activation_type,
            dropout=dropout_rate)

        self.dff = dff
        if dff is not None:
            self.ffn = FeedForward(d_model, dff)

    def call(self, symbols, objects):
        if self.use_self_attn:
            symbols = self.self_attention(symbols)
        symbols = self.relational_crossattention(symbols=symbols, inputs=objects)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.relational_crossattention.last_attn_scores

        if self.dff is not None:
            symbols = self.ffn(symbols)  # Shape `(batch_size, seq_len, d_model)`.

        return symbols
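

# Minimal usage sketch for RelationalAbstracter (illustrative only; the shapes and
# hyperparameter values below are assumptions, not values prescribed by this module).
# The symbol dimension d_model and sequence length are inferred from the input in
# `build`, and the output has the same shape as the input:
#
#     abstracter = RelationalAbstracter(num_layers=2, num_heads=4, dff=128)
#     objects = tf.random.normal((batch_size, seq_len, d_model))  # e.g. encoder output
#     abstract_states = abstracter(objects)                       # (batch_size, seq_len, d_model)

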
# The SymbolicAbstracter is an early experimental variant.
# It does not appear in the paper. We leave it here for reference.
class SymbolicAbstracter(tf.keras.layers.Layer):
    """
    An early-development variant of the 'Abstractor' module.

    This variant uses a 'symbolic' attention mechanism, in which
    Q <- S, K <- X, V <- X, where X is the input and S are learned symbols.
    """

    def __init__(
            self,
            num_layers,
            num_heads,
            dff,
            use_pos_embedding=True,
            mha_activation_type='softmax',
            dropout_rate=0.1,
            name='symbolic_abstracter'):
        super(SymbolicAbstracter, self).__init__(name=name)

        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff = dff
        self.use_pos_embedding = use_pos_embedding
        self.mha_activation_type = mha_activation_type
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        _, self.sequence_length, self.d_model = input_shape

        # define the input-independent symbol sequence used by the abstracter
        normal_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
        self.symbol_sequence = tf.Variable(
            normal_initializer(shape=(self.sequence_length, self.d_model)),
            name='symbols', trainable=True)

        # layer which adds positional embedding (to be used on symbol sequence)
        if self.use_pos_embedding:
            self.add_pos_embedding = AddPositionalEmbedding()

        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

        self.abstracter_layers = [
            SymbolicAbstracterLayer(
                d_model=self.d_model, num_heads=self.num_heads, dff=self.dff,
                mha_activation_type=self.mha_activation_type,
                dropout_rate=self.dropout_rate)
            for _ in range(self.num_layers)]

        self.last_attn_scores = None

    def call(self, encoder_context):
        # the symbol sequence is input-independent, so the same one is used for all
        # inputs in the batch (addition broadcasts it across the batch dimension)
        symbol_seq = tf.zeros_like(encoder_context) + self.symbol_sequence

        # add positional embedding
        if self.use_pos_embedding:
            symbol_seq = self.add_pos_embedding(symbol_seq)

        symbol_seq = self.dropout(symbol_seq)

        for i in range(self.num_layers):
            symbol_seq = self.abstracter_layers[i](symbol_seq, encoder_context)

        # cache the final layer's symbolic attention scores for plotting later
        self.last_attn_scores = self.abstracter_layers[-1].last_attn_scores

        return symbol_seq


class SymbolicAbstracterLayer(tf.keras.layers.Layer):
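    """
    A single layer of the SymbolicAbstracter.

    Applies self-attention over the symbol sequence, followed by 'symbolic'
    attention to the encoder context and a position-wise feedforward network.
    """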

    def __init__(
            self,
            d_model,
            num_heads,
            dff,
            mha_activation_type='softmax',
            dropout_rate=0.1,
            name=None):
        super(SymbolicAbstracterLayer, self).__init__(name=name)

        self.mha_activation_type = mha_activation_type

        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            activation_type=mha_activation_type,
            dropout=dropout_rate)

        self.symbolic_attention = SymbolicAttention(
            num_heads=num_heads,
            key_dim=d_model,
            activation_type=mha_activation_type,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x, context):
        x = self.self_attention(x=x)
        x = self.symbolic_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.symbolic_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x


class AblationAbstractor(tf.keras.layers.Layer):
    """
    An 'Ablation' Abstractor model.

    This model is the same as the RelationalAbstracter, but uses standard
    cross-attention instead of relational cross-attention. It is used to
    isolate the effect of the cross-attention scheme in experiments.
    """

    def __init__(
            self,
            num_layers,
            num_heads,
            dff,
            use_self_attn=True,
            use_pos_embedding=True,
            mha_activation_type='softmax',
            dropout_rate=0.1,
            name='ablation_model'):
        super(AblationAbstractor, self).__init__(name=name)

        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff = dff
        self.use_self_attn = use_self_attn
        self.mha_activation_type = mha_activation_type
        self.use_pos_embedding = use_pos_embedding
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        _, self.sequence_length, self.d_model = input_shape

        # define the input-independent symbol sequence used by the abstracter
        normal_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
        self.symbol_sequence = tf.Variable(
            normal_initializer(shape=(self.sequence_length, self.d_model)),
            name='symbols', trainable=True)

        # layer which adds positional embedding (to be used on symbol sequence)
        if self.use_pos_embedding:
            self.add_pos_embedding = AddPositionalEmbedding()

        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

        self.abstracter_layers = [
            AblationAbstractorLayer(
                d_model=self.d_model, num_heads=self.num_heads, dff=self.dff,
                use_self_attn=self.use_self_attn,
                mha_activation_type=self.mha_activation_type,
                dropout_rate=self.dropout_rate)
            for _ in range(self.num_layers)]

        self.last_attn_scores = None

    def call(self, encoder_context):
        # the symbol sequence is input-independent, so the same one is used for all
        # inputs in the batch (addition broadcasts it across the batch dimension)
        symbol_seq = tf.zeros_like(encoder_context) + self.symbol_sequence

        # add positional embedding
        if self.use_pos_embedding:
            symbol_seq = self.add_pos_embedding(symbol_seq)

        symbol_seq = self.dropout(symbol_seq)

        for i in range(self.num_layers):
            symbol_seq = self.abstracter_layers[i](symbol_seq, encoder_context)

        # cache the final layer's cross-attention scores for plotting later
        self.last_attn_scores = self.abstracter_layers[-1].last_attn_scores

        return symbol_seq


class AblationAbstractorLayer(tf.keras.layers.Layer):
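    """
    A single layer of the AblationAbstractor.

    Applies (optional) self-attention over the symbol sequence, followed by
    standard cross-attention to the encoder context and a position-wise
    feedforward network.
    """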

    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 use_self_attn=True,
                 mha_activation_type='softmax',
                 dropout_rate=0.1):
        super(AblationAbstractorLayer, self).__init__()

        self.use_self_attn = use_self_attn

        if use_self_attn:
            self.self_attention = GlobalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)

        self.crossattention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            activation_type=mha_activation_type,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x, context):
        if self.use_self_attn:
            x = self.self_attention(x=x)
        x = self.crossattention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.crossattention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x
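

if __name__ == "__main__":
    # A minimal smoke-test sketch. The hyperparameters and tensor shapes below are
    # illustrative assumptions, not values prescribed by this module; the check is
    # simply that each abstracter variant maps a (batch, seq_len, d_model) input
    # to an output of the same shape.
    batch_size, seq_len, d_model = 4, 8, 32
    objects = tf.random.normal((batch_size, seq_len, d_model))

    for module in (
            RelationalAbstracter(num_layers=2, num_heads=2, dff=64),
            SymbolicAbstracter(num_layers=2, num_heads=2, dff=64),
            AblationAbstractor(num_layers=2, num_heads=2, dff=64)):
        out = module(objects)
        print(type(module).__name__, out.shape)  # expected: (4, 8, 32)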