From 706b1ec21bd0ebb93b8dc602e53488398a196dab Mon Sep 17 00:00:00 2001 From: Olha Wloch Date: Fri, 22 Jan 2021 14:10:29 -0500 Subject: [PATCH 1/5] Initial Changes Implementing Controllable Parameter --- keras_lmu/layers.py | 128 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 122 insertions(+), 6 deletions(-) diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index a5ecc990..e7330665 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -38,6 +38,16 @@ class to create a recurrent Keras layer to process the whole sequence. Calling information to be projected to and from the hidden layer. hidden_cell : ``tf.keras.layers.Layer`` Keras Layer/RNNCell implementing the hidden component. + controllable : bool + If False (default), the given theta is used for all + memory vectors in conjunction with ZOH discretization. + If True, Euler's method is used, and a different theta + is dynamically generated (on-the-fly) for each memory + vector by using a sigmoid layer. The theta parameter + in this cell definition becomes the initial bias + output from the sigmoid layer. In addition, the + memory vector is saturated by a tanh nonlinearity to + combat instabilities from Euler's method. hidden_to_memory : bool If True, connect the output of the hidden component back to the memory component (default False). @@ -66,12 +76,21 @@ class to create a recurrent Keras layer to process the whole sequence. Calling Filing date: 2016-08-22. """ + # When controllable is True, this scales the minimum acceptable + # theta that is needed to ensure stability with Euler's method. + # In an ideal world, this would be 1, but due to feedback loops + # through the hidden layer (or if memory_to_memory is True) + # the minimum theta needs to be scaled as a buffer. + controllable_min_theta_multiplier = 2 + + def __init__( self, memory_d, order, theta, hidden_cell, + controllable=False, hidden_to_memory=False, memory_to_memory=False, input_to_hidden=False, @@ -87,6 +106,7 @@ def __init__( self.order = order self.theta = theta self.hidden_cell = hidden_cell + self.controllable = controllable self.hidden_to_memory = hidden_to_memory self.memory_to_memory = memory_to_memory self.input_to_hidden = input_to_hidden @@ -115,20 +135,67 @@ def __init__( self.hidden_output_size = self.hidden_cell.units self.hidden_state_size = [self.hidden_cell.units] + if controllable: + # theta is factored out into a sigmoid computation in this case + theta = 1 # only affects determination of R + Q = np.arange(order, dtype=np.float64) R = (2 * Q + 1)[:, None] / theta j, i = np.meshgrid(Q, Q) A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R B = (-1.0) ** Q[:, None] * R - C = np.ones((1, order)) - D = np.zeros((1,)) - self._A, self._B, _, _, _ = cont2discrete((A, B, C, D), dt=1.0, method="zoh") + + if controllable: + self.min_theta = ( + self.compute_min_theta(A) * self.controllable_min_theta_multiplier + ) + if self.theta <= self.min_theta: + new_theta = self.min_theta + 1 # can be any epsilon > 0 + warnings.warn( + "theta (%s) must be > %s; setting to %s" + % (self.theta, self.min_theta, new_theta) + ) + self.theta = new_theta + + # Euler's method is x <- x + dt*(Ax + Bu) + # where dt = 1 / theta, with A and B kept as is. 
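# For reference, a minimal standalone NumPy/SciPy sketch (not part of the
# patch) of the two update rules this branch switches between, for a single
# memory vector in column form; `order`, `theta`, and the input value are
# illustrative placeholders, and the layer itself handles batching and
# matrix orientation differently.
import numpy as np
from scipy.signal import cont2discrete

order, theta = 4, 10.0

# Continuous-time LMU system with theta factored out (i.e. theta = 1),
# mirroring the construction above.
Q = np.arange(order, dtype=np.float64)
R = (2 * Q + 1)[:, None]
j, i = np.meshgrid(Q, Q)
A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
B = (-1.0) ** Q[:, None] * R

m = np.zeros((order, 1))  # one memory vector
u = 1.0                   # one input sample

# Static case: fold theta back in and discretize with zero-order hold.
Ad, Bd, *_ = cont2discrete(
    (A / theta, B / theta, np.ones((1, order)), np.zeros((1,))),
    dt=1.0,
    method="zoh",
)
m_zoh = Ad @ m + Bd * u

# Controllable case: keep A and B as is and take one Euler step with
# dt = 1 / theta (the layer additionally saturates the result with tanh).
m_euler = m + (1.0 / theta) * (A @ m + B * u)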
+ self._A = A + self._B = B + + else: + C = np.ones((1, order)) + D = np.zeros((1,)) + + self._A, self._B, _, _, _ = cont2discrete( + (A, B, C, D), dt=1.0, method="zoh" + ) self.state_size = tf.nest.flatten(self.hidden_state_size) + [ self.memory_d * self.order ] self.output_size = self.hidden_output_size + @classmethod + def compute_min_theta(cls, A): + """Given continuous A matrix, returns the minimum theta for Euler's stability. + + Any theta less than this or equal to this value is guaranteed to be unstable. + But a theta greater than this value can still become unstable through + external feedback loops. And so this criteria is necessary, but not + sufficient, for stability. + """ + # https://gl.appliedbrainresearch.com/arvoelke/scratchpad/-/blob/master/notebooks/lmu_euler_stability.ipynb + e = np.linalg.eigvals(A) + return np.max(-np.abs(e) ** 2 / (2 * e.real)) + + def _theta_inv(self, control): + """Dynamically generates 1 / theta given a control signal.""" + assert self.controllable + # 1 / theta will be in the range (0, 1 / min_theta) + # <=> ( theta > min_theta ) + return tf.nn.sigmoid(control) / self.min_theta + + def build(self, input_shape): """ Builds the cell. @@ -182,6 +249,25 @@ def build(self, input_shape): trainable=False, ) + if self.controllable: + self.controller = self.add_weight( + name="lmu_controller", shape=(enc_d, self.memory_d), + ) + + # Solve self._theta_inv(init_controller_bias) == 1 / theta + # so that the initial control bias provides the desired initial theta_inv. + init_control = self.min_theta / self.theta + assert 0 < init_control < 1 # guaranteed by min_theta < theta + init_controller_bias = np.log(init_control / (1 - init_control)) + assert np.allclose(self._theta_inv(init_controller_bias), 1 / self.theta) + + self.controller_bias = self.add_weight( + name="lmu_controller_bias", + shape=(self.memory_d,), + initializer=tf.initializers.constant(init_controller_bias), + ) + + def call(self, inputs, states, training=None): """ Apply this cell to inputs. 
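# A standalone NumPy check (not part of the patch) of the two pieces above:
# the Euler stability bound from compute_min_theta and the controller bias
# chosen in build; `order` and `theta` are arbitrary illustrative values.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

order, theta = 4, 1000.0  # theta chosen comfortably above the bound

Q = np.arange(order, dtype=np.float64)
R = (2 * Q + 1)[:, None]
j, i = np.meshgrid(Q, Q)
A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R

# Any theta at or below this bound is guaranteed unstable under Euler's method.
e = np.linalg.eigvals(A)
min_theta = np.max(-np.abs(e) ** 2 / (2 * e.real))
min_theta *= 2  # controllable_min_theta_multiplier, a buffer for feedback loops

# The bias solves sigmoid(bias) / min_theta == 1 / theta, so the sigmoid
# controller initially reproduces the requested theta.
init_control = min_theta / theta  # in (0, 1) because theta > min_theta
bias = np.log(init_control / (1 - init_control))
assert np.allclose(sigmoid(bias) / min_theta, 1.0 / theta)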
@@ -206,7 +292,7 @@ def call(self, inputs, states, training=None): # compute memory input u_in = tf.concat((inputs, h[0]), axis=1) if self.hidden_to_memory else inputs if self.dropout > 0: - u_in *= self.get_dropout_mask_for_cell(u_in, training) + u_in = u_in * self.get_dropout_mask_for_cell(u_in, training) u = tf.matmul(u_in, self.kernel) if self.memory_to_memory: @@ -223,8 +309,23 @@ def call(self, inputs, states, training=None): m = tf.reshape(m, (-1, self.memory_d, self.order)) u = tf.expand_dims(u, -1) - # update memory - m = tf.matmul(m, self.A) + tf.matmul(u, self.B) + # Update memory by Euler's method (controllable) or ZOH (static) + if self.controllable: + # Compute 1 / theta on the fly as a function of (inputs, h[0]) + theta_inv = self._theta_inv( + tf.matmul(u_in, self.controller) + self.controller_bias + ) # (0, 1 / min_theta) squashing to keep Euler updates stable + + # Do Euler update with dt = 1 / theta + m = m + tf.expand_dims(theta_inv, axis=2) * ( + tf.matmul(m, self.A) + u * self.B + ) + + # Also saturate the memory to combat instabilities + m = tf.nn.tanh(m) + + else: + m = tf.matmul(m, self.AT) + tf.matmul(u, self.B) # re-combine memory/order dimensions m = tf.reshape(m, (-1, self.memory_d * self.order)) @@ -265,6 +366,7 @@ def get_config(self): order=self.order, theta=self.theta, hidden_cell=tf.keras.layers.serialize(self.hidden_cell), + controllable=self.controllable, hidden_to_memory=self.hidden_to_memory, memory_to_memory=self.memory_to_memory, input_to_hidden=self.input_to_hidden, @@ -318,6 +420,16 @@ class LMU(tf.keras.layers.Layer): information to be projected to and from the hidden layer. hidden_cell : ``tf.keras.layers.Layer`` Keras Layer/RNNCell implementing the hidden component. + controllable : bool + If False (default), the given theta is used for all + memory vectors in conjunction with ZOH discretization. + If True, Euler's method is used, and a different theta + is dynamically generated (on-the-fly) for each memory + vector by using a sigmoid layer. The theta parameter + in this cell definition becomes the initial bias + output from the sigmoid layer. In addition, the + memory vector is saturated by a tanh nonlinearity to + combat instabilities from Euler's method. hidden_to_memory : bool If True, connect the output of the hidden component back to the memory component (default False). 
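For clarity, here is a shape-level sketch of the controllable memory update performed in the cell's call() above, using random stand-ins for the encoded input, memory state, and weights; all sizes and the min_theta value are hypothetical, and in the cell these tensors come from the kernel, the A/B weights, and the lmu_controller weights.

    import tensorflow as tf

    batch, enc_d, memory_d, order = 2, 6, 3, 4
    min_theta = 8.0  # placeholder; the cell derives this from A's eigenvalues

    u_in = tf.random.normal((batch, enc_d))     # memory input (plus h if hidden_to_memory)
    u = tf.random.normal((batch, memory_d, 1))  # encoded input, one column per memory vector
    m = tf.zeros((batch, memory_d, order))      # memory state
    A = tf.random.normal((order, order))        # stand-ins for the cell's A and B weights
    B = tf.random.normal((1, order))
    controller = tf.random.normal((enc_d, memory_d))
    controller_bias = tf.zeros((memory_d,))

    # One theta per memory vector, squashed into (0, 1 / min_theta).
    theta_inv = tf.nn.sigmoid(tf.matmul(u_in, controller) + controller_bias) / min_theta

    # Euler step with dt = 1 / theta, then tanh saturation against instability.
    m = m + tf.expand_dims(theta_inv, axis=2) * (tf.matmul(m, A) + u * B)
    m = tf.nn.tanh(m)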
@@ -355,6 +467,7 @@ def __init__( order, theta, hidden_cell, + controllable=False, hidden_to_memory=False, memory_to_memory=False, input_to_hidden=False, @@ -372,6 +485,7 @@ def __init__( self.order = order self.theta = theta self.hidden_cell = hidden_cell + self.controllable = controllable self.hidden_to_memory = hidden_to_memory self.memory_to_memory = memory_to_memory self.input_to_hidden = input_to_hidden @@ -418,6 +532,7 @@ def build(self, input_shapes): order=self.order, theta=self.theta, hidden_cell=self.hidden_cell, + controllable = self.controllable, hidden_to_memory=self.hidden_to_memory, memory_to_memory=self.memory_to_memory, input_to_hidden=self.input_to_hidden, @@ -454,6 +569,7 @@ def get_config(self): order=self.order, theta=self.theta, hidden_cell=tf.keras.layers.serialize(self.hidden_cell), + controllable = self.controllable, hidden_to_memory=self.hidden_to_memory, memory_to_memory=self.memory_to_memory, input_to_hidden=self.input_to_hidden, From b7f564e64680c1e53e00221126955b38a4b6b9d0 Mon Sep 17 00:00:00 2001 From: Olha Wloch Date: Fri, 22 Jan 2021 14:47:27 -0500 Subject: [PATCH 2/5] Add support for additional kwargs in LMU Class --- keras_lmu/layers.py | 46 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index e7330665..d9ebe3d9 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -451,6 +451,34 @@ class LMU(tf.keras.layers.Layer): If True, return the full output sequence. Otherwise, return just the last output in the output sequence. + ## DOUBLE CHECK THESE DEFINITIONS, taken from + ## https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN + return_state: bool, optional + If True, return the last state in addition to the output. + go_backwards: bool, optional + If True, process the input sequence backwards and return the reversed sequence. + stateful: bool, optional + If True, the last state for each sample at index i in a batch will be used + as initial state for the sample of index i in the following batch. + unroll: bool, optional + If True, the network will be unrolled, else a symbolic loop will be used. + Unrolling can speed-up a RNN, although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + time_major: bool, optional + The shape format of the inputs and outputs tensors. + If True, the inputs and outputs will be in shape (timesteps, batch, ...), + whereas in the False case, it will be (batch, timesteps, ...). + Using time_major = True is a bit more efficient because it avoids transposes + at the beginning and end of the RNN calculation. However, most TensorFlow + data is batch-major, so by default this function accepts input and emits + output in batch-major form. + zero_output_for_mask: bool, optional + Whether the output should use zeros for the masked timesteps. + Note that this field is only used when return_sequences is True + and mask is provided. It can useful if you want to reuse the raw output + sequence of the RNN without interference from the masked timesteps, + eg, merging bidirectional RNNs. + References ---------- .. [1] Voelker and Eliasmith (2018). 
Improving spiking dynamical @@ -476,6 +504,12 @@ def __init__( dropout=0, recurrent_dropout=0, return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + time_major=False, # unsure if False should be default + zero_output_for_mask=False, **kwargs, ): @@ -542,6 +576,12 @@ def build(self, input_shapes): recurrent_dropout=self.recurrent_dropout, ), return_sequences=self.return_sequences, + return_state=self.return_state, + go_backwards=self.go_backwards, + stateful=self.stateful, + unroll=self.unroll, + time_major=self.time_major, + zero_output_for_mask=self.zero_output_for_mask, ) self.layer.build(input_shapes) @@ -578,6 +618,12 @@ def get_config(self): dropout=self.dropout, recurrent_dropout=self.recurrent_dropout, return_sequences=self.return_sequences, + return_state=self.return_state, + go_backwards=self.go_backwards, + stateful=self.stateful, + unroll=self.unroll, + time_major=self.time_major, + zero_output_for_mask=self.zero_output_for_mask, ) ) From 580b1e80cd1404798b7980849f8a6433cc9246c6 Mon Sep 17 00:00:00 2001 From: Olha Wloch Date: Mon, 25 Jan 2021 10:10:29 -0500 Subject: [PATCH 3/5] Added Ability to Pass RNN Specific Parameters in LMU --- keras_lmu/layers.py | 54 +++++---------------------------------------- 1 file changed, 5 insertions(+), 49 deletions(-) diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index d9ebe3d9..7cbcf344 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -325,7 +325,7 @@ def call(self, inputs, states, training=None): m = tf.nn.tanh(m) else: - m = tf.matmul(m, self.AT) + tf.matmul(u, self.B) + m = tf.matmul(m, self.A) + tf.matmul(u, self.B) # re-combine memory/order dimensions m = tf.reshape(m, (-1, self.memory_d * self.order)) @@ -451,34 +451,6 @@ class LMU(tf.keras.layers.Layer): If True, return the full output sequence. Otherwise, return just the last output in the output sequence. - ## DOUBLE CHECK THESE DEFINITIONS, taken from - ## https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN - return_state: bool, optional - If True, return the last state in addition to the output. - go_backwards: bool, optional - If True, process the input sequence backwards and return the reversed sequence. - stateful: bool, optional - If True, the last state for each sample at index i in a batch will be used - as initial state for the sample of index i in the following batch. - unroll: bool, optional - If True, the network will be unrolled, else a symbolic loop will be used. - Unrolling can speed-up a RNN, although it tends to be more memory-intensive. - Unrolling is only suitable for short sequences. - time_major: bool, optional - The shape format of the inputs and outputs tensors. - If True, the inputs and outputs will be in shape (timesteps, batch, ...), - whereas in the False case, it will be (batch, timesteps, ...). - Using time_major = True is a bit more efficient because it avoids transposes - at the beginning and end of the RNN calculation. However, most TensorFlow - data is batch-major, so by default this function accepts input and emits - output in batch-major form. - zero_output_for_mask: bool, optional - Whether the output should use zeros for the masked timesteps. - Note that this field is only used when return_sequences is True - and mask is provided. It can useful if you want to reuse the raw output - sequence of the RNN without interference from the masked timesteps, - eg, merging bidirectional RNNs. - References ---------- .. [1] Voelker and Eliasmith (2018). 
Improving spiking dynamical @@ -504,12 +476,6 @@ def __init__( dropout=0, recurrent_dropout=0, return_sequences=False, - return_state=False, - go_backwards=False, - stateful=False, - unroll=False, - time_major=False, # unsure if False should be default - zero_output_for_mask=False, **kwargs, ): @@ -530,7 +496,7 @@ def __init__( self.return_sequences = return_sequences self.layer = None - def build(self, input_shapes): + def build(self, input_shapes, **kwargs): """ Builds the layer. @@ -576,12 +542,7 @@ def build(self, input_shapes): recurrent_dropout=self.recurrent_dropout, ), return_sequences=self.return_sequences, - return_state=self.return_state, - go_backwards=self.go_backwards, - stateful=self.stateful, - unroll=self.unroll, - time_major=self.time_major, - zero_output_for_mask=self.zero_output_for_mask, + **kwargs, ) self.layer.build(input_shapes) @@ -599,7 +560,7 @@ def call(self, inputs, training=None): return self.layer.call(inputs, training=training) - def get_config(self): + def get_config(self, **kwargs): """Return config of layer (for serialization during model saving/loading).""" config = super().get_config() @@ -618,12 +579,7 @@ def get_config(self): dropout=self.dropout, recurrent_dropout=self.recurrent_dropout, return_sequences=self.return_sequences, - return_state=self.return_state, - go_backwards=self.go_backwards, - stateful=self.stateful, - unroll=self.unroll, - time_major=self.time_major, - zero_output_for_mask=self.zero_output_for_mask, + **kwargs, ) ) From e4095d857f4d7308db55f573882ccc570eab6317 Mon Sep 17 00:00:00 2001 From: Brent Komer Date: Mon, 25 Jan 2021 15:50:33 -0500 Subject: [PATCH 4/5] adding attributes required to support bidirectional RNN --- keras_lmu/layers.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index 7cbcf344..55f08f4f 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -450,6 +450,13 @@ class LMU(tf.keras.layers.Layer): return_sequences : bool, optional If True, return the full output sequence. Otherwise, return just the last output in the output sequence. + return_state: bool, optional + Whether to return the last state in addition to the output. + go_backwards: bool, optional + If True, process the input sequence backwards and return the reversed sequence. + stateful: bool, optional + If True, the last state for each sample at index i in a batch will be used as + initial state for the sample of index i in the following batch. References ---------- @@ -476,6 +483,9 @@ def __init__( dropout=0, recurrent_dropout=0, return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, **kwargs, ): @@ -494,9 +504,12 @@ def __init__( self.dropout = dropout self.recurrent_dropout = recurrent_dropout self.return_sequences = return_sequences + self.return_state = return_state + self.go_backwards = go_backwards + self.stateful = stateful self.layer = None - def build(self, input_shapes, **kwargs): + def build(self, input_shapes): """ Builds the layer. 
@@ -542,7 +555,9 @@ def build(self, input_shapes, **kwargs): recurrent_dropout=self.recurrent_dropout, ), return_sequences=self.return_sequences, - **kwargs, + return_state=self.return_state, + go_backwards=self.go_backwards, + stateful=self.stateful, ) self.layer.build(input_shapes) @@ -560,7 +575,7 @@ def call(self, inputs, training=None): return self.layer.call(inputs, training=training) - def get_config(self, **kwargs): + def get_config(self): """Return config of layer (for serialization during model saving/loading).""" config = super().get_config() @@ -579,7 +594,8 @@ def get_config(self, **kwargs): dropout=self.dropout, recurrent_dropout=self.recurrent_dropout, return_sequences=self.return_sequences, - **kwargs, + return_state=self.return_state, + go_backwards=self.go_backwards, ) ) From 2b17ecf478e9b410347f207b3d0256b132352f3c Mon Sep 17 00:00:00 2001 From: Olha Wloch Date: Tue, 26 Jan 2021 14:23:51 -0500 Subject: [PATCH 5/5] Added warnings import --- keras_lmu/layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index 55f08f4f..229fc28a 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -1,6 +1,7 @@ """ Core classes for the KerasLMU package. """ +import warnings import numpy as np import tensorflow as tf
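Taken together, a hypothetical usage sketch of what this series enables: a controllable-theta LMU wrapped in a bidirectional RNN, which is what the re-exposed return_state/go_backwards/stateful attributes are for. The layer sizes, order, and theta below are illustrative only.

    import tensorflow as tf
    from keras_lmu.layers import LMU

    inputs = tf.keras.Input((None, 32))  # variable-length sequences of 32 features

    lmu = LMU(
        memory_d=1,
        order=8,
        theta=100,  # if this is at or below the Euler bound, the cell warns and raises it
        hidden_cell=tf.keras.layers.SimpleRNNCell(64),
        controllable=True,  # per-memory-vector theta via the sigmoid controller
        return_sequences=True,
    )

    features = tf.keras.layers.Bidirectional(lmu)(inputs)
    outputs = tf.keras.layers.Dense(10)(features)
    model = tf.keras.Model(inputs, outputs)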