 from brainpy import math as bm
 from brainpy._src import connect, initialize as init
 from brainpy._src.context import share
-from brainpy._src.dependency_check import import_taichi
+from brainpy._src.dependency_check import import_taichi, import_braintaichi
 from brainpy._src.dnn.base import Layer
 from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
 from brainpy.check import is_initializer
 from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
 from brainpy.types import ArrayType, Sharding
 
+bti = import_braintaichi(error_if_not_found=False)
 ti = import_taichi(error_if_not_found=False)
 
 __all__ = [
@@ -238,7 +239,7 @@ def update(self, x):
     return x
 
 
-if ti is not None:
+if ti is not None and bti is not None:
 
   # @numba.njit(nogil=True, fastmath=True, parallel=False)
   # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
@@ -273,7 +274,7 @@ def _dense_on_post(
         out_w[i, j] = old_w[i, j]
 
 
-  dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
+  dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
 
 
   # @numba.njit(nogil=True, fastmath=True, parallel=False)
@@ -309,7 +310,7 @@ def _dense_on_pre(
         out_w[i, j] = old_w[i, j]
 
 
-  dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
+  dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
 
 else:
   dense_on_pre_prim = None
@@ -326,6 +327,12 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  weight = bm.as_jax(weight)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
                            outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
 
@@ -340,6 +347,12 @@ def dense_on_post(weight, spike, trace, w_min, w_max):
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  weight = bm.as_jax(weight)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return dense_on_post_prim(weight, spike, trace, w_min, w_max,
                             outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
 
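
Side note, not part of the commit: a minimal usage sketch of the dense wrapper above. It assumes braintaichi is importable (so the primitives are not None), that spike is the pre-synaptic spike vector and trace the post-synaptic trace, and that dense_on_pre is imported from the internal module brainpy._src.dnn.linear. Shapes and values are arbitrary.

import numpy as np
import brainpy.math as bm
from brainpy._src.dnn.linear import dense_on_pre  # internal helper

num_pre, num_post = 4, 3
W = bm.Variable(bm.zeros((num_pre, num_post)))                 # dense weight matrix
pre_spike = bm.asarray(np.array([True, False, True, False]))   # pre-synaptic spikes
post_trace = bm.ones(num_post) * 0.1                           # post-synaptic trace

# The wrapper now runs bm.as_jax on every argument before calling the
# braintaichi primitive, so BrainPy Arrays/Variables can be passed directly;
# updated weights are clamped to [w_min, w_max] = [0., 1.] here.
W.value = dense_on_pre(W.value, pre_spike, post_trace, 0., 1.)
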
@@ -735,7 +748,7 @@ def _csr_on_pre_update(
           out_w[i_syn] = old_w[i_syn]
 
 
-  csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
+  csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
 
 
   @ti.kernel
@@ -759,7 +772,7 @@ def _coo_on_pre_update(
         out_w[i_syn] = old_w[i_syn]
 
 
-  coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
+  coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
 
 
   @ti.kernel
@@ -783,7 +796,7 @@ def _coo_on_post_update(
         out_w[i_syn] = old_w[i_syn]
 
 
-  coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
+  coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
 
 
   # @numba.njit(nogil=True, fastmath=True, parallel=False)
@@ -824,7 +837,7 @@ def _csc_on_post_update(
           out_w[i_syn] = old_w[i_syn]
 
 
-  csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
+  csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
 
 
 else:
@@ -843,6 +856,14 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  w = bm.as_jax(w)
+  indices = bm.as_jax(indices)
+  indptr = bm.as_jax(indptr)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
@@ -857,6 +878,15 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  w = bm.as_jax(w)
+  pre_ids = bm.as_jax(pre_ids)
+  post_ids = bm.as_jax(post_ids)
+  spike = bm.as_jax(spike)
+  trace = bm.as_jax(trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
+
   return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max,
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
@@ -871,6 +901,15 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=
     w_max = np.inf
   w_min = jnp.atleast_1d(w_min)
   w_max = jnp.atleast_1d(w_max)
+
+  w = bm.as_jax(w)
+  post_ids = bm.as_jax(post_ids)
+  indptr = bm.as_jax(indptr)
+  w_ids = bm.as_jax(w_ids)
+  post_spike = bm.as_jax(post_spike)
+  pre_trace = bm.as_jax(pre_trace)
+  w_min = bm.as_jax(w_min)
+  w_max = bm.as_jax(w_max)
   return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max,
                                  outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
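
Again outside the diff: a comparable, hedged sketch for the CSR path. It assumes a fixed-probability connectivity built with brainpy's connector API, csr_on_pre_update imported from the same internal module, and that spike/trace play the same pre/post roles as in the dense case; all sizes are illustrative.

import brainpy as bp
import brainpy.math as bm
from brainpy._src.dnn.linear import csr_on_pre_update  # internal helper

pre_num, post_num = 100, 80
conn = bp.conn.FixedProb(0.2)(pre_size=pre_num, post_size=post_num)
indices, indptr = conn.require('pre2post')     # CSR structure, rows indexed by pre neuron
w = bm.random.rand(indices.size)               # one weight per synapse
pre_spike = bm.random.rand(pre_num) < 0.05     # Boolean pre-synaptic spikes
post_trace = bm.random.rand(post_num)          # post-synaptic trace

# Same pattern as the dense wrapper: every argument goes through bm.as_jax before
# the braintaichi CSR primitive is invoked, and updates are clipped to [w_min, w_max].
w = csr_on_pre_update(w, indices, indptr, pre_spike, post_trace, w_min=0., w_max=1.)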