Skip to content

Commit

Permalink
Avoid JIT compilation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Jul 11, 2024
1 parent ceb2d34 commit 2f2e512
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
python-version: "3.11"
coverage-name: latest
- script: test
tf-version: tensorflow~=2.6.0
tf-version: tensorflow~=2.8.0
python-version: "3.8"
coverage-name: oldest
- script: test
Expand Down
3 changes: 1 addition & 2 deletions .nengobones.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ manifest_in: {}

setup_py:
install_req:
- anyio<4 # not compatible with older tensorflow versions
- packaging>=20.9
- scipy>=1.0.0
- tensorflow>=2.6.0
- tensorflow>=2.8.0
tests_req:
- pytest>=6.1.0
- pytest-rng>=1.0.0
Expand Down
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Release history
0.8.0 (July 10, 2024)
=====================

*Compatible with TensorFlow 2.6 - 2.16*
*Compatible with TensorFlow 2.8 - 2.16*

**Added**

Expand Down
4 changes: 1 addition & 3 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
DropoutRNNCell as DropoutRNNCellMixin,
)

if tf_version < version.parse("2.8.0rc0"):
from tensorflow.keras.layers import Layer as BaseRandomLayer
elif tf_version < version.parse("2.13.0rc0"):
if tf_version < version.parse("2.13.0rc0"):
from keras.engine.base_layer import BaseRandomLayer
elif tf_version < version.parse("2.16.0rc0"):
from keras.src.engine.base_layer import BaseRandomLayer
Expand Down
3 changes: 1 addition & 2 deletions keras_lmu/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


def pytest_configure(config):
if version.parse(tf.__version__) >= version.parse("2.7.0"):
tf.debugging.disable_traceback_filtering()
tf.debugging.disable_traceback_filtering()
if version.parse(tf.__version__) >= version.parse("2.16.0"):
keras.config.disable_traceback_filtering()
13 changes: 11 additions & 2 deletions keras_lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def test_multivariate_lmu(rng, discretizer):
@pytest.mark.parametrize("has_input_kernel", (True, False))
@pytest.mark.parametrize("feedforward", (True, False))
@pytest.mark.parametrize("discretizer", ("zoh", "euler"))
def test_layer_vs_cell(rng, has_input_kernel, feedforward, discretizer):
def test_layer_vs_cell(rng, has_input_kernel, feedforward, discretizer, seed):
keras.utils.set_random_seed(seed)

n_steps = 10
input_d = 32
kwargs = {
Expand Down Expand Up @@ -534,6 +536,8 @@ def test_fit(feedforward, discretizer, trainable_theta):
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer="adam",
metrics=["accuracy"],
# can't JIT compile the `tf.linalg.expm` operation used in this particular case
jit_compile=not (trainable_theta and discretizer == "zoh"),
)

model.fit(x_train, y_train, epochs=10, validation_split=0.2)
Expand Down Expand Up @@ -633,7 +637,12 @@ def test_theta_update(discretizer, trainable_theta, tmp_path):
lmu = keras.layers.RNN(lmu_cell)(inputs)
model = keras.Model(inputs=inputs, outputs=lmu)

model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
model.compile(
loss=keras.losses.MeanSquaredError(),
optimizer="adam",
# can't JIT compile the `tf.linalg.expm` operation used in this particular case
jit_compile=not (trainable_theta and discretizer == "zoh"),
)

# make sure theta_inv is set correctly to initial value
assert np.allclose(lmu_cell.theta_inv.numpy(), 1 / theta)
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def read(*filenames, **kwargs):
version = runpy.run_path(str(root / "keras_lmu" / "version.py"))["version"]

install_req = [
"anyio<4",
"packaging>=20.9",
"scipy>=1.0.0",
"tensorflow>=2.6.0",
"tensorflow>=2.8.0",
]
docs_req = [
"matplotlib>=3.8.4",
Expand Down

0 comments on commit 2f2e512

Please sign in to comment.