Skip to content

Commit 2e2f073

Browse files
committed
Refactor test and change expected counts of Alloc that were due to BlasOpt
1 parent aa294ba commit 2e2f073

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

tests/tensor/test_basic.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -740,41 +740,43 @@ def check_allocs_in_fgraph(fgraph, n):
740740
def setup_method(self):
741741
self.rng = np.random.default_rng(seed=utt.fetch_seed())
742742

743-
def test_alloc_constant_folding(self):
743+
@pytest.mark.parametrize(
744+
"subtensor_fn, expected_grad_n_alloc",
745+
[
746+
# IncSubtensor1
747+
(lambda x: x[:60], 1),
748+
# AdvancedIncSubtensor1
749+
(lambda x: x[np.arange(60)], 1),
750+
# AdvancedIncSubtensor
751+
(lambda x: x[np.arange(50), np.arange(50)], 1),
752+
],
753+
)
754+
def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
744755
test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)
745756

746757
some_vector = vector("some_vector", dtype=self.dtype)
747758
some_matrix = some_vector.reshape((60, 50))
748759
variables = self.shared(np.ones((50,), dtype=self.dtype))
749-
idx = constant(np.arange(50))
750-
751-
for alloc_, (subtensor, n_alloc) in zip(
752-
self.allocs,
753-
[
754-
# IncSubtensor1
755-
(some_matrix[:60], 2),
756-
# AdvancedIncSubtensor1
757-
(some_matrix[arange(60)], 2),
758-
# AdvancedIncSubtensor
759-
(some_matrix[idx, idx], 1),
760-
],
761-
strict=True,
762-
):
763-
derp = pt_sum(dense_dot(subtensor, variables))
764760

765-
fobj = pytensor.function([some_vector], derp, mode=self.mode)
766-
grad_derp = pytensor.grad(derp, some_vector)
767-
fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
761+
subtensor = subtensor_fn(some_matrix)
768762

769-
topo_obj = fobj.maker.fgraph.toposort()
770-
assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
763+
derp = pt_sum(dense_dot(subtensor, variables))
764+
fobj = pytensor.function([some_vector], derp, mode=self.mode)
765+
assert (
766+
sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
767+
== 0
768+
)
769+
# TODO: Assert something about the value if we bothered to call it?
770+
fobj(test_params)
771771

772-
topo_grad = fgrad.maker.fgraph.toposort()
773-
assert (
774-
sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
775-
), (alloc_, subtensor, n_alloc, topo_grad)
776-
fobj(test_params)
777-
fgrad(test_params)
772+
grad_derp = pytensor.grad(derp, some_vector)
773+
fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
774+
assert (
775+
sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
776+
== expected_grad_n_alloc
777+
)
778+
# TODO: Assert something about the value if we bothered to call it?
779+
fgrad(test_params)
778780

779781
def test_alloc_output(self):
780782
val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)

0 commit comments

Comments
 (0)