@@ -740,41 +740,43 @@ def check_allocs_in_fgraph(fgraph, n):
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())
 
-    def test_alloc_constant_folding(self):
+    @pytest.mark.parametrize(
+        "subtensor_fn, expected_grad_n_alloc",
+        [
+            # IncSubtensor1
+            (lambda x: x[:60], 1),
+            # AdvancedIncSubtensor1
+            (lambda x: x[np.arange(60)], 1),
+            # AdvancedIncSubtensor
+            (lambda x: x[np.arange(50), np.arange(50)], 1),
+        ],
+    )
+    def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
         test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)
 
         some_vector = vector("some_vector", dtype=self.dtype)
         some_matrix = some_vector.reshape((60, 50))
         variables = self.shared(np.ones((50,), dtype=self.dtype))
-        idx = constant(np.arange(50))
-
-        for alloc_, (subtensor, n_alloc) in zip(
-            self.allocs,
-            [
-                # IncSubtensor1
-                (some_matrix[:60], 2),
-                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
-                # AdvancedIncSubtensor
-                (some_matrix[idx, idx], 1),
-            ],
-            strict=True,
-        ):
-            derp = pt_sum(dense_dot(subtensor, variables))
 
-            fobj = pytensor.function([some_vector], derp, mode=self.mode)
-            grad_derp = pytensor.grad(derp, some_vector)
-            fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        subtensor = subtensor_fn(some_matrix)
 
-            topo_obj = fobj.maker.fgraph.toposort()
-            assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
+        derp = pt_sum(dense_dot(subtensor, variables))
+        fobj = pytensor.function([some_vector], derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
+            == 0
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fobj(test_params)
 
-            topo_grad = fgrad.maker.fgraph.toposort()
-            assert (
-                sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
-            ), (alloc_, subtensor, n_alloc, topo_grad)
-            fobj(test_params)
-            fgrad(test_params)
+        grad_derp = pytensor.grad(derp, some_vector)
+        fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
+            == expected_grad_n_alloc
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fgrad(test_params)
 
     def test_alloc_output(self):
         val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)
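
The hunk header places this change near a `check_allocs_in_fgraph(fgraph, n)` helper whose definition is not part of the diff. Both new assertions repeat the same Alloc-counting pattern inline, so presumably the helper wraps exactly that check. A minimal sketch of what it might look like, assuming `Alloc` is imported from `pytensor.tensor.basic` and that the intent matches the inline asserts above (this is a reconstruction, not the actual definition):

```python
from pytensor.tensor.basic import Alloc


def check_allocs_in_fgraph(fgraph, n):
    # Hypothetical reconstruction: count the Alloc nodes remaining in the
    # compiled graph and assert the expected number. n == 0 means constant
    # folding removed every Alloc; n > 0 means that many are expected to stay.
    assert sum(isinstance(node.op, Alloc) for node in fgraph.apply_nodes) == n
```

With such a helper, the two inline assertions above could collapse to `check_allocs_in_fgraph(fobj.maker.fgraph, 0)` and `check_allocs_in_fgraph(fgrad.maker.fgraph, expected_grad_n_alloc)`.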