Skip to content

Commit 165b9ee

Browse files
Find back updates (brainpy#646)
* roll back previous updates * update JIT transform * Update * Update delayvars.py * remove using internal API of jax for registering pytree objects --------- Co-authored-by: He Sichao <[email protected]>
1 parent 5e54107 commit 165b9ee

24 files changed

+610
-729
lines changed

brainpy/_src/dnn/conv.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def update(self, x):
160160
nonbatching = False
161161
if x.ndim == self.num_spatial_dims + 1:
162162
nonbatching = True
163-
x = x.unsqueeze(0)
163+
x = bm.unsqueeze(x, 0)
164164
w = self.w.value
165165
if self.mask is not None:
166166
try:
@@ -190,6 +190,9 @@ def __repr__(self):
190190
class Conv1d(_GeneralConv):
191191
"""One-dimensional convolution.
192192
193+
The input should a 2d array with the shape of ``[H, C]``, or
194+
a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size.
195+
193196
Parameters
194197
----------
195198
in_channels: int
@@ -282,6 +285,9 @@ def _check_input_dim(self, x):
282285
class Conv2d(_GeneralConv):
283286
"""Two-dimensional convolution.
284287
288+
The input should a 3d array with the shape of ``[H, W, C]``, or
289+
a 4d array with the shape of ``[B, H, W, C]``.
290+
285291
Parameters
286292
----------
287293
in_channels: int
@@ -375,6 +381,9 @@ def _check_input_dim(self, x):
375381
class Conv3d(_GeneralConv):
376382
"""Three-dimensional convolution.
377383
384+
The input should a 3d array with the shape of ``[H, W, D, C]``, or
385+
a 4d array with the shape of ``[B, H, W, D, C]``.
386+
378387
Parameters
379388
----------
380389
in_channels: int

brainpy/_src/dnn/tests/test_activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from absl.testing import parameterized
21
from absl.testing import absltest
2+
from absl.testing import parameterized
33
import brainpy as bp
44
import brainpy.math as bm
55

brainpy/_src/dnn/tests/test_conv_layers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
# -*- coding: utf-8 -*-
22

3-
from unittest import TestCase
4-
from absl.testing import absltest
53
import jax.numpy as jnp
6-
import brainpy.math as bm
4+
from absl.testing import absltest
75
from absl.testing import parameterized
6+
87
import brainpy as bp
98
import brainpy.math as bm
109

1110

1211
class TestConv(parameterized.TestCase):
1312
def test_Conv2D_img(self):
14-
bm.random.seed()
1513
img = jnp.zeros((2, 200, 198, 4))
1614
for k in range(4):
1715
x = 30 + 60 * k
@@ -24,21 +22,22 @@ def test_Conv2D_img(self):
2422
strides=(2, 1), padding='VALID', groups=4)
2523
out = net(img)
2624
print("out shape: ", out.shape)
25+
self.assertEqual(out.shape, (2, 99, 196, 32))
2726
# print("First output channel:")
2827
# plt.figure(figsize=(10, 10))
2928
# plt.imshow(np.array(img)[0, :, :, 0])
3029
# plt.show()
3130
bm.clear_buffer_memory()
3231

3332
def test_conv1D(self):
34-
bm.random.seed()
3533
with bp.math.training_environment():
3634
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
3735

3836
input = bp.math.ones((2, 5, 3))
3937

4038
out = model(input)
4139
print("out shape: ", out.shape)
40+
self.assertEqual(out.shape, (2, 5, 32))
4241
# print("First output channel:")
4342
# plt.figure(figsize=(10, 10))
4443
# plt.imshow(np.array(out)[0, :, :])
@@ -54,6 +53,7 @@ def test_conv2D(self):
5453

5554
out = model(input)
5655
print("out shape: ", out.shape)
56+
self.assertEqual(out.shape, (2, 5, 5, 32))
5757
# print("First output channel:")
5858
# plt.figure(figsize=(10, 10))
5959
# plt.imshow(np.array(out)[0, :, :, 31])
@@ -67,6 +67,7 @@ def test_conv3D(self):
6767
input = bp.math.ones((2, 5, 5, 5, 3))
6868
out = model(input)
6969
print("out shape: ", out.shape)
70+
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
7071
bm.clear_buffer_memory()
7172

7273

brainpy/_src/dnn/tests/test_function.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
# -*- coding: utf-8 -*-
22

3-
from unittest import TestCase
4-
5-
import jax.numpy as jnp
6-
import brainpy.math as bm
73
from absl.testing import absltest
84
from absl.testing import parameterized
5+
96
import brainpy as bp
7+
import brainpy.math as bm
108

119

1210
class TestFunction(parameterized.TestCase):

brainpy/_src/dnn/tests/test_normalization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import brainpy.math as bm
21
from absl.testing import parameterized
32
from absl.testing import absltest
3+
44
import brainpy as bp
5+
import brainpy.math as bm
56

67

78
class Test_Normalization(parameterized.TestCase):

brainpy/_src/dnn/tests/test_pooling_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import jax
44
import jax.numpy as jnp
55
import numpy as np
6-
from absl.testing import parameterized
76
from absl.testing import absltest
7+
from absl.testing import parameterized
88

99
import brainpy as bp
1010
import brainpy.math as bm

brainpy/_src/math/delayvars.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from brainpy import check
1212
from brainpy.check import is_float, is_integer, jit_error
1313
from brainpy.errors import UnsupportedError
14-
from .compat_numpy import vstack, broadcast_to
14+
from .compat_numpy import broadcast_to, expand_dims, concatenate
1515
from .environment import get_dt, get_float
1616
from .interoperability import as_jax
1717
from .ndarray import ndarray, Array
@@ -392,6 +392,7 @@ def reset(
392392
dtype=delay_target.dtype),
393393
batch_axis=batch_axis)
394394
else:
395+
self.data.value
395396
self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
396397
dtype=delay_target.dtype)
397398

@@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None):
472473

473474
elif self.update_method == CONCAT_UPDATE:
474475
if self.num_delay_step >= 2:
475-
self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]])
476+
self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0)
476477
else:
477478
self.data[:] = value
478479

brainpy/_src/math/object_transform/autograd.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@
2828
get_stack_cache,
2929
cache_stack)
3030
from .base import (BrainPyObject, ObjectTransform)
31-
from .variables import (Variable,
32-
VariableStack,
33-
current_transform_number,
34-
new_transform)
31+
from .variables import (Variable, VariableStack)
32+
from .tools import eval_shape
3533

3634
__all__ = [
3735
'grad', # gradient of scalar function
@@ -203,36 +201,21 @@ def __call__(self, *args, **kwargs):
203201
elif not self._eval_dyn_vars: # evaluate dynamical variables
204202
stack = get_stack_cache(self.target)
205203
if stack is None:
206-
with new_transform(self):
207-
with VariableStack() as stack:
208-
if current_transform_number() > 1:
209-
rets = self._transform(
210-
[v.value for v in self._grad_vars], # variables for gradients
211-
{}, # dynamical variables
212-
*args,
213-
**kwargs
214-
)
215-
else:
216-
rets = jax.eval_shape(
217-
self._transform,
218-
[v.value for v in self._grad_vars], # variables for gradients
219-
{}, # dynamical variables
220-
*args,
221-
**kwargs
222-
)
204+
with VariableStack() as stack:
205+
rets = eval_shape(self._transform,
206+
[v.value for v in self._grad_vars], # variables for gradients
207+
{}, # dynamical variables
208+
*args,
209+
**kwargs)
223210
cache_stack(self.target, stack)
224211

225-
self._dyn_vars = stack
226-
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
227-
self._eval_dyn_vars = True
212+
self._dyn_vars = stack
213+
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
214+
self._eval_dyn_vars = True
228215

229-
# if not the outermost transformation
230-
if current_transform_number():
231-
return self._return(rets)
232-
else:
233-
self._dyn_vars = stack
234-
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
235-
self._eval_dyn_vars = True
216+
# if not the outermost transformation
217+
if not stack.is_first_stack():
218+
return self._return(rets)
236219

237220
rets = self._transform(
238221
[v.value for v in self._grad_vars], # variables for gradients

brainpy/_src/math/object_transform/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import jax
1414
import numpy as np
15-
from jax._src.tree_util import _registry
1615
from jax.tree_util import register_pytree_node_class
1716

1817
from brainpy._src.math.modes import Mode
@@ -27,6 +26,8 @@
2726

2827
variable_ = None
2928
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
29+
registered = set()
30+
3031

3132
__all__ = [
3233
'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform',
@@ -91,8 +92,9 @@ def __init__(self, name=None):
9192
super().__init__()
9293

9394
if defaults.bp_object_as_pytree:
94-
if self.__class__ not in _registry:
95+
if self.__class__ not in registered:
9596
register_pytree_node_class(self.__class__)
97+
registered.add(self.__class__)
9698

9799
# check whether the object has a unique name.
98100
self._name = None

0 commit comments

Comments
 (0)