11
11
from brainpy import math as bm
12
12
from brainpy ._src import connect , initialize as init
13
13
from brainpy ._src .context import share
14
- from brainpy ._src .dependency_check import import_taichi
14
+ from brainpy ._src .dependency_check import import_taichi , import_braintaichi
15
15
from brainpy ._src .dnn .base import Layer
16
16
from brainpy ._src .mixin import SupportOnline , SupportOffline , SupportSTDP
17
17
from brainpy .check import is_initializer
20
20
from brainpy .initialize import XavierNormal , ZeroInit , Initializer , parameter
21
21
from brainpy .types import ArrayType , Sharding
22
22
23
+ bti = import_braintaichi (error_if_not_found = False )
23
24
ti = import_taichi (error_if_not_found = False )
24
25
25
26
__all__ = [
@@ -238,7 +239,7 @@ def update(self, x):
238
239
return x
239
240
240
241
241
- if ti is not None :
242
+ if ti is not None and bti is not None :
242
243
243
244
# @numba.njit(nogil=True, fastmath=True, parallel=False)
244
245
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
@@ -273,7 +274,7 @@ def _dense_on_post(
273
274
out_w [i , j ] = old_w [i , j ]
274
275
275
276
276
- dense_on_post_prim = bm .XLACustomOp (cpu_kernel = _dense_on_post , gpu_kernel = _dense_on_post )
277
+ dense_on_post_prim = bti .XLACustomOp (cpu_kernel = _dense_on_post , gpu_kernel = _dense_on_post )
277
278
278
279
279
280
# @numba.njit(nogil=True, fastmath=True, parallel=False)
@@ -309,7 +310,7 @@ def _dense_on_pre(
309
310
out_w [i , j ] = old_w [i , j ]
310
311
311
312
312
- dense_on_pre_prim = bm .XLACustomOp (cpu_kernel = _dense_on_pre , gpu_kernel = _dense_on_pre )
313
+ dense_on_pre_prim = bti .XLACustomOp (cpu_kernel = _dense_on_pre , gpu_kernel = _dense_on_pre )
313
314
314
315
else :
315
316
dense_on_pre_prim = None
@@ -326,6 +327,12 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):
326
327
w_max = np .inf
327
328
w_min = jnp .atleast_1d (w_min )
328
329
w_max = jnp .atleast_1d (w_max )
330
+
331
+ weight = bm .as_jax (weight )
332
+ spike = bm .as_jax (spike )
333
+ trace = bm .as_jax (trace )
334
+ w_min = bm .as_jax (w_min )
335
+ w_max = bm .as_jax (w_max )
329
336
return dense_on_pre_prim (weight , spike , trace , w_min , w_max ,
330
337
outs = [jax .ShapeDtypeStruct (weight .shape , weight .dtype )])[0 ]
331
338
@@ -340,6 +347,12 @@ def dense_on_post(weight, spike, trace, w_min, w_max):
340
347
w_max = np .inf
341
348
w_min = jnp .atleast_1d (w_min )
342
349
w_max = jnp .atleast_1d (w_max )
350
+
351
+ weight = bm .as_jax (weight )
352
+ spike = bm .as_jax (spike )
353
+ trace = bm .as_jax (trace )
354
+ w_min = bm .as_jax (w_min )
355
+ w_max = bm .as_jax (w_max )
343
356
return dense_on_post_prim (weight , spike , trace , w_min , w_max ,
344
357
outs = [jax .ShapeDtypeStruct (weight .shape , weight .dtype )])[0 ]
345
358
@@ -735,7 +748,7 @@ def _csr_on_pre_update(
735
748
out_w [i_syn ] = old_w [i_syn ]
736
749
737
750
738
- csr_on_pre_update_prim = bm .XLACustomOp (cpu_kernel = _csr_on_pre_update , gpu_kernel = _csr_on_pre_update )
751
+ csr_on_pre_update_prim = bti .XLACustomOp (cpu_kernel = _csr_on_pre_update , gpu_kernel = _csr_on_pre_update )
739
752
740
753
741
754
@ti .kernel
@@ -759,7 +772,7 @@ def _coo_on_pre_update(
759
772
out_w [i_syn ] = old_w [i_syn ]
760
773
761
774
762
- coo_on_pre_update_prim = bm .XLACustomOp (cpu_kernel = _coo_on_pre_update , gpu_kernel = _coo_on_pre_update )
775
+ coo_on_pre_update_prim = bti .XLACustomOp (cpu_kernel = _coo_on_pre_update , gpu_kernel = _coo_on_pre_update )
763
776
764
777
765
778
@ti .kernel
@@ -783,7 +796,7 @@ def _coo_on_post_update(
783
796
out_w [i_syn ] = old_w [i_syn ]
784
797
785
798
786
- coo_on_post_update_prim = bm .XLACustomOp (cpu_kernel = _coo_on_post_update , gpu_kernel = _coo_on_post_update )
799
+ coo_on_post_update_prim = bti .XLACustomOp (cpu_kernel = _coo_on_post_update , gpu_kernel = _coo_on_post_update )
787
800
788
801
789
802
# @numba.njit(nogil=True, fastmath=True, parallel=False)
@@ -824,7 +837,7 @@ def _csc_on_post_update(
824
837
out_w [i_syn ] = old_w [i_syn ]
825
838
826
839
827
- csc_on_post_update_prim = bm .XLACustomOp (cpu_kernel = _csc_on_post_update , gpu_kernel = _csc_on_post_update )
840
+ csc_on_post_update_prim = bti .XLACustomOp (cpu_kernel = _csc_on_post_update , gpu_kernel = _csc_on_post_update )
828
841
829
842
830
843
else :
@@ -843,6 +856,14 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
843
856
w_max = np .inf
844
857
w_min = jnp .atleast_1d (w_min )
845
858
w_max = jnp .atleast_1d (w_max )
859
+
860
+ w = bm .as_jax (w )
861
+ indices = bm .as_jax (indices )
862
+ indptr = bm .as_jax (indptr )
863
+ spike = bm .as_jax (spike )
864
+ trace = bm .as_jax (trace )
865
+ w_min = bm .as_jax (w_min )
866
+ w_max = bm .as_jax (w_max )
846
867
return csr_on_pre_update_prim (w , indices , indptr , spike , trace , w_min , w_max ,
847
868
outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
848
869
@@ -857,6 +878,15 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None
857
878
w_max = np .inf
858
879
w_min = jnp .atleast_1d (w_min )
859
880
w_max = jnp .atleast_1d (w_max )
881
+
882
+ w = bm .as_jax (w )
883
+ pre_ids = bm .as_jax (pre_ids )
884
+ post_ids = bm .as_jax (post_ids )
885
+ spike = bm .as_jax (spike )
886
+ trace = bm .as_jax (trace )
887
+ w_min = bm .as_jax (w_min )
888
+ w_max = bm .as_jax (w_max )
889
+
860
890
return coo_on_pre_update_prim (w , pre_ids , post_ids , spike , trace , w_min , w_max ,
861
891
outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
862
892
@@ -871,6 +901,15 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=
871
901
w_max = np .inf
872
902
w_min = jnp .atleast_1d (w_min )
873
903
w_max = jnp .atleast_1d (w_max )
904
+
905
+ w = bm .as_jax (w )
906
+ post_ids = bm .as_jax (post_ids )
907
+ indptr = bm .as_jax (indptr )
908
+ w_ids = bm .as_jax (w_ids )
909
+ post_spike = bm .as_jax (post_spike )
910
+ pre_trace = bm .as_jax (pre_trace )
911
+ w_min = bm .as_jax (w_min )
912
+ w_max = bm .as_jax (w_max )
874
913
return csc_on_post_update_prim (w , post_ids , indptr , w_ids , post_spike , pre_trace , w_min , w_max ,
875
914
outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
876
915
0 commit comments