Skip to content

Commit

Permalink
Updated to use eqxi.GetKey and tree_allclose
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 27, 2023
1 parent 494e7e7 commit 78700c0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 69 deletions.
28 changes: 2 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import random

import equinox.internal as eqxi
import jax
import jax.random as jr
import pytest
from jaxtyping import PRNGKeyArray


jax.config.update("jax_enable_x64", True)


# This offers reproducability -- the initial seed is printed in the repr so we can see
# it when a test fails.
# Note the `eq=False`, which means that `_GetKey `objects have `__eq__` and `__hash__`
# based on object identity.
@dataclasses.dataclass(eq=False)
class _GetKey:
seed: int
call: int
key: PRNGKeyArray

def __init__(self, seed: int):
self.seed = seed
self.call = 0
self.key = jr.PRNGKey(seed)

def __call__(self):
self.call += 1
return jr.fold_in(self.key, self.call)


@pytest.fixture
def getkey():
return _GetKey(random.randint(0, 2**31 - 1))
return eqxi.GetKey()
45 changes: 2 additions & 43 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.

import functools as ft
import operator
import random
from collections.abc import Callable
from typing import Any, TypeVar

Expand All @@ -39,47 +37,8 @@
Aux = TypeVar("Aux")


def getkey():
return jr.PRNGKey(random.randint(0, 2**31 - 1))


def _tree_allclose(x, y, **kwargs):
if type(x) is not type(y):
return False
if isinstance(x, jnp.ndarray): # pyright: ignore
if jnp.issubdtype(x.dtype, jnp.inexact):
return (
x.shape == y.shape
and x.dtype == y.dtype
and jnp.allclose(x, y, **kwargs)
)
else:
return x.shape == y.shape and x.dtype == y.dtype and jnp.all(x == y)
elif isinstance(x, np.ndarray):
if np.issubdtype(x.dtype, np.inexact):
return (
x.shape == y.shape
and x.dtype == y.dtype
and np.allclose(x, y, **kwargs)
)
else:
return x.shape == y.shape and x.dtype == y.dtype and np.all(x == y)
elif isinstance(x, jax.ShapeDtypeStruct):
assert x.shape == y.shape and x.dtype == y.dtype
else:
return x == y


def tree_allclose(x, y, **kwargs):
"""As `jnp.allclose`, except:
- It also supports PyTree arguments.
- It mandates that shapes match as well (no broadcasting)
"""
same_structure = jtu.tree_structure(x) == jtu.tree_structure(y)
allclose = ft.partial(_tree_allclose, **kwargs)
return same_structure and jtu.tree_reduce(
operator.and_, jtu.tree_map(allclose, x, y), True
)
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)


def finite_difference_jvp(fn, primals, tangents, eps=None, **kwargs):
Expand Down

0 comments on commit 78700c0

Please sign in to comment.