WIP: LMU controllable #33

Open · wants to merge 5 commits into base: main (showing changes from all commits)
147 changes: 141 additions & 6 deletions keras_lmu/layers.py
@@ -1,6 +1,7 @@
"""
Core classes for the KerasLMU package.
"""
import warnings

import numpy as np
import tensorflow as tf
@@ -38,6 +39,16 @@ class to create a recurrent Keras layer to process the whole sequence. Calling
information to be projected to and from the hidden layer.
hidden_cell : ``tf.keras.layers.Layer``
Keras Layer/RNNCell implementing the hidden component.
controllable : bool
If False (default), the given theta is used for all
memory vectors in conjunction with ZOH discretization.
If True, Euler's method is used instead, and a different theta
is generated dynamically (on-the-fly) for each memory
vector by a sigmoid layer. In that case the theta parameter
sets the initial bias of the sigmoid layer, so that each
memory vector starts out with the given theta. In addition, the
memory vector is saturated by a tanh nonlinearity to
combat instabilities arising from Euler's method.
hidden_to_memory : bool
If True, connect the output of the hidden component back to the memory
component (default False).
@@ -66,12 +77,21 @@ class to create a recurrent Keras layer to process the whole sequence. Calling
Filing date: 2016-08-22.
"""

# When controllable is True, this scales the minimum theta
# required for stability with Euler's method. Ideally this
# would be 1, but feedback loops through the hidden layer
# (or when memory_to_memory is True) mean the minimum theta
# needs to be scaled up as a safety buffer.
controllable_min_theta_multiplier = 2


def __init__(
self,
memory_d,
order,
theta,
hidden_cell,
controllable=False,
hidden_to_memory=False,
memory_to_memory=False,
input_to_hidden=False,
@@ -87,6 +107,7 @@ def __init__(
self.order = order
self.theta = theta
self.hidden_cell = hidden_cell
self.controllable = controllable
self.hidden_to_memory = hidden_to_memory
self.memory_to_memory = memory_to_memory
self.input_to_hidden = input_to_hidden
@@ -115,20 +136,67 @@ def __init__(
self.hidden_output_size = self.hidden_cell.units
self.hidden_state_size = [self.hidden_cell.units]

if controllable:
# theta is factored out into a sigmoid computation in this case
theta = 1 # only affects determination of R

Q = np.arange(order, dtype=np.float64)
R = (2 * Q + 1)[:, None] / theta
j, i = np.meshgrid(Q, Q)
A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
B = (-1.0) ** Q[:, None] * R
C = np.ones((1, order))
D = np.zeros((1,))
self._A, self._B, _, _, _ = cont2discrete((A, B, C, D), dt=1.0, method="zoh")

if controllable:
self.min_theta = (
self.compute_min_theta(A) * self.controllable_min_theta_multiplier
)
if self.theta <= self.min_theta:
new_theta = self.min_theta + 1 # can be any epsilon > 0
warnings.warn(
"theta (%s) must be > %s; setting to %s"
% (self.theta, self.min_theta, new_theta)
)
self.theta = new_theta

# Euler's method is x <- x + dt*(Ax + Bu)
# where dt = 1 / theta, with A and B kept as is.
self._A = A
self._B = B

else:
C = np.ones((1, order))
D = np.zeros((1,))

self._A, self._B, _, _, _ = cont2discrete(
(A, B, C, D), dt=1.0, method="zoh"
)

self.state_size = tf.nest.flatten(self.hidden_state_size) + [
self.memory_d * self.order
]
self.output_size = self.hidden_output_size

@classmethod
def compute_min_theta(cls, A):
"""Given continuous A matrix, returns the minimum theta for Euler's stability.

Any theta less than this or equal to this value is guaranteed to be unstable.
But a theta greater than this value can still become unstable through
external feedback loops. And so this criteria is necessary, but not
sufficient, for stability.
"""
# https://gl.appliedbrainresearch.com/arvoelke/scratchpad/-/blob/master/notebooks/lmu_euler_stability.ipynb
e = np.linalg.eigvals(A)
return np.max(-np.abs(e) ** 2 / (2 * e.real))
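As a standalone sanity check of this criterion (illustrative, not part of the diff): Euler integration of dx/dt = Ax with step dt = 1/theta updates x <- (I + A/theta)x, which is stable iff |1 + lambda/theta| < 1 for every eigenvalue lambda of A, i.e. theta > -|lambda|^2 / (2 * Re(lambda)).

import numpy as np

# Continuous LMU A matrix for a small order (same construction as above, theta = 1)
order = 6
Q = np.arange(order, dtype=np.float64)
R = (2 * Q + 1)[:, None]
j, i = np.meshgrid(Q, Q)
A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R

e = np.linalg.eigvals(A)
min_theta = np.max(-np.abs(e) ** 2 / (2 * e.real))

def euler_radius(theta):
    # Spectral radius of the Euler update matrix I + A / theta
    return np.max(np.abs(np.linalg.eigvals(np.eye(order) + A / theta)))

print(euler_radius(0.9 * min_theta))  # > 1: the Euler update diverges
print(euler_radius(1.1 * min_theta))  # < 1: the Euler update is stable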

def _theta_inv(self, control):
"""Dynamically generates 1 / theta given a control signal."""
assert self.controllable
# 1 / theta will be in the range (0, 1 / min_theta)
# <=> ( theta > min_theta )
return tf.nn.sigmoid(control) / self.min_theta


def build(self, input_shape):
"""
Builds the cell.
@@ -182,6 +250,25 @@ def build(self, input_shape):
trainable=False,
)

if self.controllable:
self.controller = self.add_weight(
name="lmu_controller", shape=(enc_d, self.memory_d),
)

# Solve self._theta_inv(init_controller_bias) == 1 / theta
# so that the initial control bias provides the desired initial theta_inv.
init_control = self.min_theta / self.theta
assert 0 < init_control < 1 # guaranteed by min_theta < theta
init_controller_bias = np.log(init_control / (1 - init_control))
assert np.allclose(self._theta_inv(init_controller_bias), 1 / self.theta)

self.controller_bias = self.add_weight(
name="lmu_controller_bias",
shape=(self.memory_d,),
initializer=tf.initializers.constant(init_controller_bias),
)
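The bias initialization above is just the inverse of the sigmoid (a logit). A small standalone check of that algebra, with illustrative values:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

min_theta, theta = 10.0, 50.0                     # any values with theta > min_theta
init_control = min_theta / theta                  # target sigmoid output, in (0, 1)
bias = np.log(init_control / (1 - init_control))  # logit of the target
theta_inv = sigmoid(bias) / min_theta             # what _theta_inv returns at init
assert np.isclose(theta_inv, 1 / theta)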


def call(self, inputs, states, training=None):
"""
Apply this cell to inputs.
Expand All @@ -206,7 +293,7 @@ def call(self, inputs, states, training=None):
# compute memory input
u_in = tf.concat((inputs, h[0]), axis=1) if self.hidden_to_memory else inputs
if self.dropout > 0:
u_in *= self.get_dropout_mask_for_cell(u_in, training)
u_in = u_in * self.get_dropout_mask_for_cell(u_in, training)
u = tf.matmul(u_in, self.kernel)

if self.memory_to_memory:
@@ -223,8 +310,23 @@ def call(self, inputs, states, training=None):
m = tf.reshape(m, (-1, self.memory_d, self.order))
u = tf.expand_dims(u, -1)

# update memory
m = tf.matmul(m, self.A) + tf.matmul(u, self.B)
# Update memory by Euler's method (controllable) or ZOH (static)
if self.controllable:
# Compute 1 / theta on the fly as a function of (inputs, h[0])
theta_inv = self._theta_inv(
tf.matmul(u_in, self.controller) + self.controller_bias
) # (0, 1 / min_theta) squashing to keep Euler updates stable

# Do Euler update with dt = 1 / theta
m = m + tf.expand_dims(theta_inv, axis=2) * (
tf.matmul(m, self.A) + u * self.B
)

# Also saturate the memory to combat instabilities
m = tf.nn.tanh(m)

else:
m = tf.matmul(m, self.A) + tf.matmul(u, self.B)
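For reference, the two update rules written per memory vector in plain NumPy, ignoring batch and tensor layout (an illustrative restatement, not the layer's exact tensor code):

import numpy as np

def zoh_step(m, u, A_bar, B_bar):
    # Static theta: m <- A_bar @ m + B_bar * u, where (A_bar, B_bar) is the ZOH
    # discretization (dt = 1) of the continuous system with theta folded into R.
    return A_bar @ m + B_bar[:, 0] * u

def controllable_step(m, u, A, B, theta_inv):
    # Dynamic theta: one Euler step of dm/dt = A @ m + B * u with step dt = 1/theta
    # (A, B built with theta = 1), then tanh saturation to keep the memory bounded.
    return np.tanh(m + theta_inv * (A @ m + B[:, 0] * u))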

# re-combine memory/order dimensions
m = tf.reshape(m, (-1, self.memory_d * self.order))
@@ -265,6 +367,7 @@ def get_config(self):
order=self.order,
theta=self.theta,
hidden_cell=tf.keras.layers.serialize(self.hidden_cell),
controllable=self.controllable,
hidden_to_memory=self.hidden_to_memory,
memory_to_memory=self.memory_to_memory,
input_to_hidden=self.input_to_hidden,
@@ -318,6 +421,16 @@ class LMU(tf.keras.layers.Layer):
information to be projected to and from the hidden layer.
hidden_cell : ``tf.keras.layers.Layer``
Keras Layer/RNNCell implementing the hidden component.
controllable : bool
If False (default), the given theta is used for all
memory vectors in conjunction with ZOH discretization.
If True, Euler's method is used instead, and a different theta
is generated dynamically (on-the-fly) for each memory
vector by a sigmoid layer. In that case the theta parameter
sets the initial bias of the sigmoid layer, so that each
memory vector starts out with the given theta. In addition, the
memory vector is saturated by a tanh nonlinearity to
combat instabilities arising from Euler's method.
hidden_to_memory : bool
If True, connect the output of the hidden component back to the memory
component (default False).
@@ -338,6 +451,13 @@ class LMU(tf.keras.layers.Layer):
return_sequences : bool, optional
If True, return the full output sequence. Otherwise, return just the last
output in the output sequence.
return_state : bool, optional
Whether to return the last state in addition to the output.
go_backwards : bool, optional
If True, process the input sequence backwards and return the reversed sequence.
stateful : bool, optional
If True, the last state for each sample at index i in a batch will be used as
the initial state for the sample of index i in the following batch.

References
----------
@@ -355,6 +475,7 @@ def __init__(
order,
theta,
hidden_cell,
controllable=False,
hidden_to_memory=False,
memory_to_memory=False,
input_to_hidden=False,
@@ -363,6 +484,9 @@ def __init__(
dropout=0,
recurrent_dropout=0,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
**kwargs,
):

@@ -372,6 +496,7 @@ def __init__(
self.order = order
self.theta = theta
self.hidden_cell = hidden_cell
self.controllable = controllable
self.hidden_to_memory = hidden_to_memory
self.memory_to_memory = memory_to_memory
self.input_to_hidden = input_to_hidden
@@ -380,6 +505,9 @@ def __init__(
self.dropout = dropout
self.recurrent_dropout = recurrent_dropout
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.layer = None

def build(self, input_shapes):
@@ -418,6 +546,7 @@ def build(self, input_shapes):
order=self.order,
theta=self.theta,
hidden_cell=self.hidden_cell,
controllable=self.controllable,
hidden_to_memory=self.hidden_to_memory,
memory_to_memory=self.memory_to_memory,
input_to_hidden=self.input_to_hidden,
@@ -427,6 +556,9 @@ def build(self, input_shapes):
recurrent_dropout=self.recurrent_dropout,
),
return_sequences=self.return_sequences,
return_state=self.return_state,
go_backwards=self.go_backwards,
stateful=self.stateful,
Contributor @arvoelke commented on Mar 25, 2021:

Note these flags are being ignored if the LMUFFT is being selected (see if branch above). It should only be selected if the flags are the default values supported by LMUFFT. That said, the LMUFFT can likely be extended to support them.

)

self.layer.build(input_shapes)
@@ -454,6 +586,7 @@ def get_config(self):
order=self.order,
theta=self.theta,
hidden_cell=tf.keras.layers.serialize(self.hidden_cell),
controllable=self.controllable,
hidden_to_memory=self.hidden_to_memory,
memory_to_memory=self.memory_to_memory,
input_to_hidden=self.input_to_hidden,
@@ -462,6 +595,8 @@ def get_config(self):
dropout=self.dropout,
recurrent_dropout=self.recurrent_dropout,
return_sequences=self.return_sequences,
return_state=self.return_state,
go_backwards=self.go_backwards,
)
)
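Finally, an end-to-end sketch of the LMU layer with the new flags (illustrative values only; per the review note above, these flags only take effect on the RNN, non-FFT, code path):

import tensorflow as tf
import keras_lmu

lmu = keras_lmu.LMU(
    memory_d=1,
    order=64,
    theta=128,                 # initial theta when controllable=True
    hidden_cell=tf.keras.layers.SimpleRNNCell(128),
    controllable=True,
    return_sequences=True,
    return_state=False,
    go_backwards=False,
    stateful=False,
)

inputs = tf.keras.Input(shape=(None, 1))   # unknown sequence length keeps the RNN path
outputs = lmu(inputs)
model = tf.keras.Model(inputs, outputs)
model.summary()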
