# abstractor.py
"""
This is an implementation of the Abstractor module which does not use tensorflow's MultiHeadAttention implementation.
Instead it uses the `multi_head_relation.py`. This allows some flexibility but at the cost of performance.
We later forked and modified tensorflow's implementation of MultiHeadAttention to add the flexibility we wanted
(e.g., to use a different relation activation function than softmax), which makes this implementation unnecessary.
"""
import tensorflow as tf
from tensorflow.keras import layers
from multi_head_relation import MultiHeadRelation
from transformer_modules import GlobalSelfAttention, create_positional_encoding, FeedForward
class Abstractor(tf.keras.layers.Layer):
def __init__(self,
num_layers,
rel_dim,
dff,
symbol_dim=None,
use_learned_symbols=True,
proj_dim=None,
symmetric_rels=False,
encoder_kwargs=None,
rel_activation_type='softmax',
use_self_attn=False,
use_layer_norm=False,
dropout_rate=0.,
name=None):
"""
create an Abstractor layer.
Models relations between objects via a relation tensor (from MultiHeadRelation),
and performs message-passing on a set of input-independent symbolic parameters
based on the relation tensor ("(relational) symbolic message-passing").
Unlike RelationalAbstractor, this layer does not use tensorflow's MultiHeadAttention,
instead implementing 'symbolic message-passing' directly from scratch.
Parameters
----------
num_layers : int
number of Abstractor layers (i.e., number of symbolic message-passing operations)
rel_dim : int
dimension of relations. applies to MultiHeadRelation in each layer.
dff : int
hidden-layer dimension of the feedforward network applied after each message-passing step
symbol_dim : int, optional
dimension of symbols, by default None
use_learned_symbols : bool, optional
whether to use learned symbols or nonparametric sinusoidal symbols.
If learned, there will be a limit on the input length. by default True
proj_dim : int, optional
dimension of projections in MultiHeadRelation layers, by default None
symmetric_rels : bool, optional
whether to model relations as symmetric or not in MultiHeadRelation layers, by default False
encoder_kwargs : dict, optional
kwargs of Dense encoders in MultiHeadRelation layers, by default None
rel_activation_type : str, optional
name of activation function to use on relation tensor, by default 'softmax'
use_self_attn : bool, optional
whether or not to use self-attention, by default False
use_layer_norm : bool, optional
whether to apply layer normalization after each message-passing step, by default False
dropout_rate : float, optional
dropout rate, by default 0.
name : str, optional
name of layer, by default None
"""
super().__init__(name=name)
self.num_layers = num_layers
self.rel_dim = rel_dim
self.dff = dff
self.proj_dim = proj_dim
self.symmetric_rels = symmetric_rels
self.encoder_kwargs = encoder_kwargs
self.symbol_dim = symbol_dim
self.use_learned_symbols = use_learned_symbols
self.rel_activation_type = rel_activation_type
self.use_self_attn = use_self_attn
self.use_layer_norm = use_layer_norm
self.dropout_rate = dropout_rate
self.max_length = 1024 # TODO: make this configurable?
def build(self, input_shape):
_, self.sequence_length, self.object_dim = input_shape
self.max_length = max(self.sequence_length, self.max_length)
# if symbol_dim is not given, use the same dimension as the objects
if self.symbol_dim is None:
self.symbol_dim = self.object_dim
if self.proj_dim is None:
self.proj_dim = self.object_dim
# define the input-independent symbolic input vector sequence
if self.use_learned_symbols:
normal_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
self.symbol_sequence = tf.Variable(
normal_initializer(shape=(self.sequence_length, self.symbol_dim)),
name='symbols', trainable=True)
else:
# create non-parametric sinusoidal symbols
self.symbol_sequence = create_positional_encoding(length=self.max_length, depth=self.symbol_dim)
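# unlike learned symbols (tied to the build-time sequence length), sinusoidal symbols
# cover inputs of any length up to max_length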
if self.use_self_attn:
self.self_attention_layers = [GlobalSelfAttention(
num_heads=self.rel_dim,
key_dim=self.proj_dim,
activation_type='softmax',
dropout=self.dropout_rate) for _ in range(self.num_layers)]
# MultiHeadRelation layer for each layer of Abstractor
self.multi_head_relation_layers = [MultiHeadRelation(
rel_dim=self.rel_dim, proj_dim=self.proj_dim,
symmetric=self.symmetric_rels, dense_kwargs=self.encoder_kwargs)
for _ in range(self.num_layers)]
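# relation activation: Softmax over axis=-2 normalizes each receiver's relations over the
# source-object axis k of R[b, i, k, r], i.e., the axis contracted in the einsum below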
if self.rel_activation_type == 'softmax':
self.rel_activation = tf.keras.layers.Softmax(axis=-2)
else:
self.rel_activation = tf.keras.layers.Activation(self.rel_activation_type)
# W_o^h: output projection layers, one per relation dimension; each projects to symbol_dim // rel_dim
self.symbol_proj_layers = [[layers.Dense(self.symbol_dim // self.rel_dim) for _ in range(self.rel_dim)] for _ in range(self.num_layers)]
# feedforward layers
self.ff_layers = [FeedForward(self.symbol_dim, self.dff) for _ in range(self.num_layers)]
if self.use_layer_norm:
self.layer_norms = [layers.LayerNormalization() for _ in range(self.num_layers)]  # one LayerNormalization per layer (list multiplication would reuse a single instance)
self.dropout = tf.keras.layers.Dropout(self.dropout_rate)
def call(self, inputs):
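# inputs: [batch_size, m, object_dim]; use only the first m symbols to match the input length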
m = tf.shape(inputs)[1]
symbol_sequence = self.symbol_sequence[:m, :]
for i in range(self.num_layers):
# get relation tensor via MultiHeadRelation layer
rel_tensor = self.multi_head_relation_layers[i](inputs) # shape: [b, m, m, d_r]
# apply activation to relation tensor (e.g.: softmax)
rel_tensor = self.rel_activation(rel_tensor)
# perform symbolic message-passing based on relation tensor
# A[b, i, j, r] = sum_k R[b, i, k, r] * S[k, j]  (per relation dim r: A_r = R_r @ S)
if i == 0:  # on the first iteration, the symbol sequence is untransformed, of shape [m, d_s]
abstract_symbol_seq = tf.einsum('bikr,kj->bijr', rel_tensor, symbol_sequence) # shape: [b, m, d_s, d_r]
else:  # on subsequent iterations, the symbol sequence has been transformed and has shape [b, m, d_s]
abstract_symbol_seq = tf.einsum('bikr,bkj->bijr', rel_tensor, abstract_symbol_seq) # shape: [b, m, d_s, d_r]
# project and concatenate
abstract_symbol_seq = tf.concat([self.symbol_proj_layers[i][r](abstract_symbol_seq[:, :, :, r]) for r in range(self.rel_dim)], axis=2) # shape: [b, m, d_s]
# transform symbol sequence via dense layer to return to its original dimension
abstract_symbol_seq = self.ff_layers[i](abstract_symbol_seq) # shape: [b, m, d_s]
if self.use_layer_norm:
abstract_symbol_seq = self.layer_norms[i](abstract_symbol_seq)
# optionally apply self-attention to the transformed symbol sequence
if self.use_self_attn:
abstract_symbol_seq = self.self_attention_layers[i](abstract_symbol_seq) # shape [b, m, d_s]
# dropout
abstract_symbol_seq = self.dropout(abstract_symbol_seq)
return abstract_symbol_seq
def get_config(self):
config = super(Abstractor, self).get_config()
config.update(
{
'num_layers': self.num_layers,
'rel_dim': self.rel_dim,
'dff': self.dff,
'proj_dim': self.proj_dim,
'symmetric_rels': self.symmetric_rels,
'encoder_kwargs': self.encoder_kwargs,
'symbol_dim': self.symbol_dim,
'use_learned_symbols': self.use_learned_symbols,
'rel_activation_type': self.rel_activation_type,
'use_self_attn': self.use_self_attn,
'use_layer_norm': self.use_layer_norm,
'dropout_rate': self.dropout_rate
})
return config
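

# Minimal usage sketch (illustrative only): builds an Abstractor on random data.
# The hyperparameter values below are arbitrary examples, and this assumes
# `multi_head_relation.py` and `transformer_modules.py` are on the import path.
if __name__ == '__main__':
    batch_size, seq_len, object_dim = 2, 10, 64

    abstractor = Abstractor(
        num_layers=2,
        rel_dim=4,
        dff=128,
        symbol_dim=64,
        rel_activation_type='softmax',
        use_self_attn=False,
        use_layer_norm=True,
        dropout_rate=0.1)

    objects = tf.random.normal((batch_size, seq_len, object_dim))
    abstract_states = abstractor(objects)
    print(abstract_states.shape)  # expected: (batch_size, seq_len, symbol_dim)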