Skip to content

Commit 2a5adea

Browse files
[bug] Fixing compatibility issues with jax (#691)
* Add code for processing and computing with progress bar * fix * Revert "Add code for processing and computing with progress bar" This reverts commit 12977af. * Add JAX configuration for disabling async dispatch * Update controls.py * Update brainpy-changelog.md --------- Co-authored-by: Chaoming Wang <[email protected]>
1 parent a51e3a7 commit 2a5adea

File tree

4 files changed

+92
-13
lines changed

4 files changed

+92
-13
lines changed

brainpy-changelog.md

+74
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,80 @@
22

33

44
## brainpy>2.3.x
5+
### Version 2.6.1
6+
#### Breaking Changes
7+
- Fixing compatibility issues between `numpy` and `jax`
8+
9+
#### What's Changed
10+
* [doc] Add Chinese version of `operator_custom_with_cupy.ipynb` and Rename it's title by @Routhleck in https://github.com/brainpy/BrainPy/pull/659
11+
* Fix "amsgrad" is used before being defined when initializing the AdamW optimizer by @CloudyDory in https://github.com/brainpy/BrainPy/pull/660
12+
* fix issue #661 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/662
13+
* fix flax RNN interoperation, fix #663 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/665
14+
* [fix] Replace jax.experimental.host_callback with jax.pure_callback by @Routhleck in https://github.com/brainpy/BrainPy/pull/670
15+
* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24 by @Routhleck in https://github.com/brainpy/BrainPy/pull/669
16+
* [math] Fix `CustomOpByNumba` on `multiple_results=True` by @Routhleck in https://github.com/brainpy/BrainPy/pull/671
17+
* [math] Implementing event-driven sparse matrix @ matrix operators by @Routhleck in https://github.com/brainpy/BrainPy/pull/613
18+
* [math] Add getting JIT connect matrix method for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/672
19+
* [math] Add get JIT weight matrix methods(Uniform & Normal) for `brainpy.dnn.linear` by @Routhleck in https://github.com/brainpy/BrainPy/pull/673
20+
* support `Integrator.to_math_expr()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/674
21+
* [bug] Replace `collections.Iterable` with `collections.abc.Iterable` by @Routhleck in https://github.com/brainpy/BrainPy/pull/677
22+
* Fix surrogate gradient function and numpy 2.0 compatibility by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/679
23+
* :arrow_up: Bump docker/build-push-action from 5 to 6 by @dependabot in https://github.com/brainpy/BrainPy/pull/678
24+
* fix the incorrect verbose of `clear_name_cache()` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/681
25+
* [bug] Fix prograss bar is not displayed and updated as expected by @Routhleck in https://github.com/brainpy/BrainPy/pull/683
26+
* Fix autograd by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/687
27+
28+
29+
**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.6.0...V2.6.1
30+
31+
### Version 2.6.0
32+
33+
#### New Features
34+
35+
This release provides several new features, including:
36+
37+
- ``MLIR`` registered operator customization interface in ``brainpy.math.XLACustomOp``.
38+
- Operator customization with CuPy JIT interface.
39+
- Bug fixes.
40+
41+
42+
43+
#### What's Changed
44+
* [doc] Fix the wrong path of more examples of `operator customized with taichi.ipynb` by @Routhleck in https://github.com/brainpy/BrainPy/pull/612
45+
* [docs] Add colab link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/614
46+
* Update requirements-doc.txt to fix doc building temporally by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/617
47+
* [math] Rebase operator customization using MLIR registration interface by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/618
48+
* [docs] Add kaggle link for documentation notebooks by @Routhleck in https://github.com/brainpy/BrainPy/pull/619
49+
* update requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/620
50+
* require `brainpylib>=0.2.6` for `jax>=0.4.24` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/622
51+
* [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/624
52+
* doc hierarchy update by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/630
53+
* Standardizing and generalizing object-oriented transformations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/628
54+
* fix #626 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/631
55+
* Fix delayvar not correct in concat mode by @CloudyDory in https://github.com/brainpy/BrainPy/pull/632
56+
* [dependency] remove hard dependency of `taichi` and `numba` by @Routhleck in https://github.com/brainpy/BrainPy/pull/635
57+
* `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/639
58+
* add `brainpy.math.surrogate..Surrogate` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/638
59+
* Enable brainpy object as pytree so that it can be applied with ``jax.jit`` etc. directly by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/625
60+
* Fix ci by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/640
61+
* Clean taichi AOT caches by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/643
62+
* [ci] Fix windows pytest fatal exception by @Routhleck in https://github.com/brainpy/BrainPy/pull/644
63+
* [math] Support more than 8 parameters of taichi gpu custom operator definition by @Routhleck in https://github.com/brainpy/BrainPy/pull/642
64+
* Doc for ``brainpylib>=0.3.0`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/645
65+
* Find back updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/646
66+
* Update installation instruction by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/651
67+
* Fix delay bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/650
68+
* update doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/652
69+
* [math] Add new customize operators with `cupy` by @Routhleck in https://github.com/brainpy/BrainPy/pull/653
70+
* [math] Fix taichi custom operator on gpu backend by @Routhleck in https://github.com/brainpy/BrainPy/pull/655
71+
* update cupy operator custom doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/656
72+
* version 2.6.0 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/657
73+
* Upgrade CI by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/658
74+
75+
## New Contributors
76+
* @CloudyDory made their first contribution in https://github.com/brainpy/BrainPy/pull/632
77+
78+
**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.5.0...V2.6.0
579

680

781
### Version 2.5.0

brainpy/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,8 @@
153153

154154
del deprecation_getattr2
155155

156+
# jax config
157+
import os
158+
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
159+
import jax
160+
jax.config.update('jax_cpu_enable_async_dispatch', False)

brainpy/_src/math/object_transform/controls.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ def fun2scan(carry, x):
915915
dyn_vars[k]._value = dyn_vars_data[k]
916916
carry, results = body_fun(carry, x)
917917
if progress_bar:
918-
jax.pure_callback(lambda *arg: bar.update(), ())
918+
jax.debug.callback(lambda *arg: bar.update(), ())
919919
carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
920920
return (dyn_vars.dict_data(), carry), results
921921

examples/dynamics_simulation/hh_model.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,16 @@ def __init__(self, size):
4343
self.KNa.add_elem()
4444

4545

46-
# hh = HH(1)
47-
# I, length = bp.inputs.section_input(values=[0, 5, 0],
48-
# durations=[100, 500, 100],
49-
# return_length=True)
50-
# runner = bp.DSRunner(
51-
# hh,
52-
# monitors=['V', 'INa.p', 'INa.q', 'IK.p'],
53-
# inputs=[hh.input, I, 'iter'],
54-
# )
55-
# runner.run(length)
56-
#
57-
# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
46+
hh = HH(1)
47+
I, length = bp.inputs.section_input(values=[0, 5, 0],
48+
durations=[100, 500, 100],
49+
return_length=True)
50+
runner = bp.DSRunner(
51+
hh,
52+
monitors=['V', 'INa.p', 'INa.q', 'IK.p'],
53+
inputs=[hh.input, I, 'iter'],
54+
)
55+
runner.run(length)
56+
57+
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
5858

0 commit comments

Comments
 (0)