Skip to content

Update JAX import paths for compatibility with version 0.5.0 and upda… #722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu

BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms.

BrainPy requires ``jax<0.6.0``.

For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html)


Expand Down
11 changes: 7 additions & 4 deletions brainpy/_src/integrators/_jaxpr_to_source_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
import jax.numpy as jnp
import numpy as np
from jax._src.sharding_impls import UNSPECIFIED
from jax.core import Literal, Var, Jaxpr
if jax.__version__ >= '0.5.0':
from jax.extend.core import Primitive, Literal, Var, Jaxpr
else:
from jax.core import Primitive, Literal, Var, Jaxpr

__all__ = [
'fn_to_python_code',
Expand Down Expand Up @@ -187,7 +190,7 @@ def fn_to_python_code(fn, *args, **kwargs):
return source


def jaxpr_to_python_code(jaxpr: jax.core.Jaxpr,
def jaxpr_to_python_code(jaxpr: Jaxpr,
fn_name: str = "generated_function"):
"""
Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr.
Expand Down Expand Up @@ -367,7 +370,7 @@ def _maybe_wrap_fn_for_leaves(node, f, num_args):


def jaxpr_to_py_ast(state: SourcerorState,
jaxpr: jax.core.Jaxpr,
jaxpr: Jaxpr,
fn_name: str = "function"):
# Generate argument declarations
ast_args = [ast.arg(arg=state.str_name(var), annotation=None)
Expand Down Expand Up @@ -405,7 +408,7 @@ def jaxpr_to_py_ast(state: SourcerorState,
return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[])


def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr):
def constant_fold_jaxpr(jaxpr: Jaxpr):
"""
Given a jaxpr, return a new jaxpr with all constant folding done.
"""
Expand Down
7 changes: 6 additions & 1 deletion brainpy/_src/math/remove_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@


import jax.numpy as jnp
from jax.core import Primitive, ShapedArray
import jax
if jax.__version__ >= '0.5.0':
from jax.extend.core import Primitive
else:
from jax.core import Primitive
from jax.core import ShapedArray
from jax.interpreters import batching, mlir, xla
from .ndarray import Array

Expand Down
8 changes: 7 additions & 1 deletion brainpy/_src/math/sparse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import partial
from typing import Tuple


import jax
import numpy as np
from brainpy._src.math.interoperability import as_jax
from jax import core, numpy as jnp
Expand All @@ -12,6 +14,10 @@
from jax.interpreters import mlir, ad
from jax.tree_util import tree_flatten, tree_unflatten
from jaxlib import gpu_sparse
if jax.__version__ >= '0.5.0':
from jax.extend.core import Primitive
else:
from jax.core import Primitive

__all__ = [
'coo_to_csr',
Expand Down Expand Up @@ -171,7 +177,7 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape):
return _csr_extract(indices, indptr, ct), indices, indptr


csr_to_dense_p = core.Primitive('csr_to_dense')
csr_to_dense_p = Primitive('csr_to_dense')
csr_to_dense_p.def_impl(_csr_to_dense_impl)
csr_to_dense_p.def_abstract_eval(_csr_to_dense_abstract_eval)
ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None)
Expand Down
6 changes: 5 additions & 1 deletion brainpy/_src/math/surrogate/_one_input_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import jax
import jax.numpy as jnp
import jax.scipy as sci
from jax.core import Primitive

if jax.__version__ >= '0.5.0':
from jax.extend.core import Primitive
else:
from jax.core import Primitive
from jax.interpreters import batching, ad, mlir

from brainpy._src.math.interoperability import as_jax
Expand Down
4 changes: 2 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
jax
jaxlib
jax<0.6.0
jaxlib<0.6.0
absl-py<=2.1.0
brainstate<=0.1.0.post20241210
braintaichi<=0.0.4
Expand Down
4 changes: 2 additions & 2 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
tqdm
jax
jaxlib
jax<0.6.0
jaxlib<0.6.0
matplotlib
numpy
scipy
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy
jax
jax<0.6.0
tqdm
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.9',
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'],
install_requires=['numpy>=1.15', 'jax>=0.4.13,<0.6.0', 'tqdm'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
Expand Down
Loading