pytensor/pytensor/link/jax/dispatch/basic.py
Lines 22 to 28 in 0e29d76
```python
@singledispatch
def jax_typify(data, dtype=None, **kwargs):
    r"""Convert instances of PyTensor `Type`\s to JAX types."""
    if dtype is None:
        return data
    else:
        return jnp.array(data, dtype=dtype)
```
(The same approach has since been copied into the PyTorch linker as well.)
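For readers less familiar with the pattern, the quoted snippet relies on `functools.singledispatch`: the undecorated function is the generic fallback, and more specific overloads can be registered per input type. The fallback only converts when a `dtype` is passed; otherwise it hands `data` back untouched. The sketch below reproduces that behavior under stated assumptions. `typify_sketch` is a hypothetical name, NumPy stands in for `jax.numpy` so the example runs without JAX, and the `np.ndarray` overload is purely illustrative, not something taken from PyTensor.

```python
from functools import singledispatch

import numpy as np  # stand-in for jax.numpy in this sketch (assumption)


@singledispatch
def typify_sketch(data, dtype=None, **kwargs):
    """Hypothetical generic fallback mirroring the quoted `jax_typify`."""
    if dtype is None:
        # No target dtype requested: the object is passed through unchanged.
        return data
    # A dtype was requested: build an array of that dtype.
    return np.array(data, dtype=dtype)


@typify_sketch.register(np.ndarray)
def _(data, dtype=None, **kwargs):
    # Hypothetical overload: singledispatch picks it based on the type of
    # the first positional argument. np.asarray avoids a copy when no
    # conversion is needed.
    return np.asarray(data, dtype=dtype)


x = [1, 2, 3]
print(typify_sketch(x) is x)  # True: fallback returns the list as-is
print(typify_sketch(x, dtype="float32").dtype)  # float32
print(typify_sketch(np.arange(3), dtype="int8").dtype)  # int8, via the overload
```

Any type without a registered overload falls through to the generic function, so with `dtype=None` it is returned exactly as received, whatever it is.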