Commit 65baea4

[math] Replace math operators with braintaichi (#698)
* Fix test bug
* Update ad_support.py
* Update test_activation.py
* Skip test
* Replace math operators with `braintaichi`
* Update CI.yml
* Update
* Fix bugs
* Update
* Fix bugs
1 parent 26cce53 · commit 65baea4

35 files changed (+177, -4507 lines)

brainpy/_src/dependency_check.py (+28, -1)
@@ -6,6 +6,8 @@
 __all__ = [
   'import_taichi',
   'raise_taichi_not_found',
+  'import_braintaichi',
+  'raise_braintaichi_not_found',
   'import_numba',
   'raise_numba_not_found',
   'import_cupy',
@@ -16,10 +18,11 @@
 ]
 
 _minimal_brainpylib_version = '0.2.6'
-_minimal_taichi_version = (1, 7, 0)
+_minimal_taichi_version = (1, 7, 2)
 
 numba = None
 taichi = None
+braintaichi = None
 cupy = None
 cupy_jit = None
 brainpylib_cpu_ops = None
@@ -33,6 +36,10 @@
 cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
                      'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
                      'For CUDA v12.x > pip install cupy-cuda12x\n')
+braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n'
+                            '> pip install braintaichi -U')
+
+
 os.environ["TI_LOG_LEVEL"] = "error"
 
 
@@ -69,6 +76,26 @@ def import_taichi(error_if_not_found=True):
 def raise_taichi_not_found(*args, **kwargs):
   raise ModuleNotFoundError(taichi_install_info)
 
+def import_braintaichi(error_if_not_found=True):
+  """Internal API to import braintaichi.
+
+  If braintaichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
+  otherwise it will return None.
+  """
+  global braintaichi
+  if braintaichi is None:
+    try:
+      import braintaichi as braintaichi
+    except ModuleNotFoundError:
+      if error_if_not_found:
+        raise_braintaichi_not_found()
+      else:
+        return None
+  return braintaichi
+
+def raise_braintaichi_not_found():
+  raise ModuleNotFoundError(braintaichi_install_info)
+
 
 def import_numba(error_if_not_found=True):
   """

brainpy/_src/dnn/linear.py (+47, -8)
@@ -11,7 +11,7 @@
 from brainpy import math as bm
 from brainpy._src import connect, initialize as init
 from brainpy._src.context import share
-from brainpy._src.dependency_check import import_taichi
+from brainpy._src.dependency_check import import_taichi, import_braintaichi
 from brainpy._src.dnn.base import Layer
 from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
 from brainpy.check import is_initializer
@@ -20,6 +20,7 @@
 from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
 from brainpy.types import ArrayType, Sharding
 
+bti = import_braintaichi(error_if_not_found=False)
 ti = import_taichi(error_if_not_found=False)
 
 __all__ = [
@@ -238,7 +239,7 @@ def update(self, x):
     return x
 
 
-if ti is not None:
+if ti is not None and bti is not None:
 
   # @numba.njit(nogil=True, fastmath=True, parallel=False)
   # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
@@ -273,7 +274,7 @@ def _dense_on_post(
           out_w[i, j] = old_w[i, j]
 
 
-  dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
+  dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
 
 
   # @numba.njit(nogil=True, fastmath=True, parallel=False)
@@ -309,7 +310,7 @@ def _dense_on_pre(
           out_w[i, j] = old_w[i, j]
 
 
-  dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
+  dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
 
 else:
   dense_on_pre_prim = None
@@ -326,6 +327,12 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  weight = bm.as_jax(weight)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
                            outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
 
@@ -340,6 +347,12 @@ def dense_on_post(weight, spike, trace, w_min, w_max):
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  weight = bm.as_jax(weight)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return dense_on_post_prim(weight, spike, trace, w_min, w_max,
                             outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
 
@@ -735,7 +748,7 @@ def _csr_on_pre_update(
         out_w[i_syn] = old_w[i_syn]
 
 
-  csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
+  csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
 
 
   @ti.kernel
@@ -759,7 +772,7 @@ def _coo_on_pre_update(
         out_w[i_syn] = old_w[i_syn]
 
 
-  coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
+  coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
 
 
   @ti.kernel
@@ -783,7 +796,7 @@ def _coo_on_post_update(
         out_w[i_syn] = old_w[i_syn]
 
 
-  coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
+  coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
 
 
   # @numba.njit(nogil=True, fastmath=True, parallel=False)
@@ -824,7 +837,7 @@ def _csc_on_post_update(
         out_w[i_syn] = old_w[i_syn]
 
 
-  csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
+  csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
 
 
 else:
@@ -843,6 +856,14 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  w = bm.as_jax(w)
+  indices = bm.as_jax(indices)
+  indptr = bm.as_jax(indptr)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
@@ -857,6 +878,15 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  w = bm.as_jax(w)
+  pre_ids = bm.as_jax(pre_ids)
+  post_ids = bm.as_jax(post_ids)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
+
   return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max,
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
@@ -871,6 +901,15 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  w = bm.as_jax(w)
+  post_ids = bm.as_jax(post_ids)
+  indptr = bm.as_jax(indptr)
+  w_ids = bm.as_jax(w_ids)
+  post_spike = bm.as_jax(post_spike)
+  pre_trace = bm.as_jax(pre_trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max,
                                  outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
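
Two mechanical changes repeat throughout this file: every kernel registration moves from bm.XLACustomOp to braintaichi's bti.XLACustomOp, and every public wrapper converts its operands to raw JAX arrays with bm.as_jax before invoking the primitive. A sketch of the resulting calling convention, with illustrative shapes (the primitive name comes from the diff above; this assumes the module was imported with taichi and braintaichi available):

import jax
import jax.numpy as jnp
import brainpy.math as bm

weight = bm.as_jax(jnp.zeros((100, 100)))       # bm.Array -> jax.Array (no-op on jax.Array)
spike = bm.as_jax(jnp.zeros(100, dtype=bool))
trace = bm.as_jax(jnp.zeros(100))
w_min = bm.as_jax(jnp.atleast_1d(-jnp.inf))
w_max = bm.as_jax(jnp.atleast_1d(jnp.inf))

# `outs` declares the output shape/dtype so XLA can allocate the buffer the
# Taichi kernel writes into; the primitive returns a list, hence the [0].
new_w = dense_on_pre_prim(weight, spike, trace, w_min, w_max,
                          outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]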

brainpy/_src/dnn/tests/test_linear.py (+5, -8)
@@ -1,14 +1,11 @@
 import pytest
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax.numpy as jnp
 
 import brainpy as bp
 import brainpy.math as bm
 
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
 
 
 class TestLinear(parameterized.TestCase):
@@ -104,11 +101,11 @@ def test_CSRLinear(self, conn):
     bm.random.seed()
     f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
     x = bm.random.random((16, 100))
-    y = f(x)
+    y = f(jnp.asarray(x))
     self.assertTrue(y.shape == (16, 100))
 
     x = bm.random.random((100,))
-    y = f(x)
+    y = f(jnp.asarray(x))
     self.assertTrue(y.shape == (100,))
     bm.clear_buffer_memory()
 
@@ -123,10 +120,10 @@ def test_EventCSRLinear(self, conn):
     bm.random.seed()
     f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal())
     x = bm.random.random((16, 100))
-    y = f(x)
+    y = f(jnp.asarray(x))
     self.assertTrue(y.shape == (16, 100))
     x = bm.random.random((100,))
-    y = f(x)
+    y = f(jnp.asarray(x))
     self.assertTrue(y.shape == (100,))
     bm.clear_buffer_memory()
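
With the module-level "no taichi" skip removed, these tests always run, and inputs are passed as plain JAX arrays via jnp.asarray. A minimal standalone repro of the updated pattern (the FixedProb connector is an illustrative choice; the parameterized tests supply their own conn):

import jax.numpy as jnp
import brainpy as bp
import brainpy.math as bm

conn = bp.conn.FixedProb(0.1)(pre_size=100, post_size=100)  # illustrative connectivity
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((100,))
y = f(jnp.asarray(x))  # strip the bm.Array wrapper before calling the layer
assert y.shape == (100,)
bm.clear_buffer_memory()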

brainpy/_src/dnn/tests/test_mode.py (-4)
@@ -4,10 +4,6 @@
 
 import brainpy as bp
 import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
 
 
 class Test_Conv(parameterized.TestCase):

brainpy/_src/dyn/projections/tests/test_STDP.py (-4)
@@ -6,10 +6,6 @@
 
 import brainpy as bp
 import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
 
 bm.set_platform('cpu')

brainpy/_src/dyn/projections/tests/test_aligns.py (-4)
@@ -5,10 +5,6 @@
 import brainpy as bp
 import brainpy.math as bm
 
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
 
 neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                 V_initializer=bp.init.Normal(-55., 2.))

brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py (-4)
@@ -7,10 +7,6 @@
 import brainpy as bp
 import brainpy.math as bm
 from brainpy._src.dynold.synapses import abstract_models
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
 
 
 class Test_Abstract_Synapse(parameterized.TestCase):

brainpy/_src/dynold/synapses/tests/test_biological_synapses.py (-4)
@@ -6,10 +6,6 @@
 import brainpy as bp
 import brainpy.math as bm
 
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
 
 biological_models = [
   bp.synapses.AMPA,

brainpy/_src/math/__init__.py (+1, -1)
@@ -44,7 +44,7 @@
 from .compat_numpy import *
 from .compat_tensorflow import *
 from .others import *
-from . import random, linalg, fft, tifunc
+from . import random, linalg, fft
 
 # operators
 from .op_register import *
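
Note that dropping tifunc from this import list removes brainpy.math.tifunc from the public namespace. Downstream code that relied on it presumably moves to the external package instead; this is an inference from the commit's direction, not something the diff itself states:

# before: from brainpy.math import tifunc
# after:  the taichi-operator layer is expected to come from braintaichi
import braintaichi  # pip install braintaichi -U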
