@@ -758,41 +758,43 @@ def check_allocs_in_fgraph(fgraph, n):
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())
 
-    def test_alloc_constant_folding(self):
+    @pytest.mark.parametrize(
+        "subtensor_fn, expected_grad_n_alloc",
+        [
+            # IncSubtensor1
+            (lambda x: x[:60], 1),
+            # AdvancedIncSubtensor1
+            (lambda x: x[np.arange(60)], 1),
+            # AdvancedIncSubtensor
+            (lambda x: x[np.arange(50), np.arange(50)], 1),
+        ],
+    )
+    def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
         test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)
 
         some_vector = vector("some_vector", dtype=self.dtype)
         some_matrix = some_vector.reshape((60, 50))
         variables = self.shared(np.ones((50,), dtype=self.dtype))
-        idx = constant(np.arange(50))
-
-        for alloc_, (subtensor, n_alloc) in zip(
-            self.allocs,
-            [
-                # IncSubtensor1
-                (some_matrix[:60], 2),
-                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
-                # AdvancedIncSubtensor
-                (some_matrix[idx, idx], 1),
-            ],
-            strict=True,
-        ):
-            derp = pt_sum(dense_dot(subtensor, variables))
 
-            fobj = pytensor.function([some_vector], derp, mode=self.mode)
-            grad_derp = pytensor.grad(derp, some_vector)
-            fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        subtensor = subtensor_fn(some_matrix)
 
-            topo_obj = fobj.maker.fgraph.toposort()
-            assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
+        derp = pt_sum(dense_dot(subtensor, variables))
+        fobj = pytensor.function([some_vector], derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
+            == 0
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fobj(test_params)
 
-            topo_grad = fgrad.maker.fgraph.toposort()
-            assert (
-                sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
-            ), (alloc_, subtensor, n_alloc, topo_grad)
-            fobj(test_params)
-            fgrad(test_params)
+        grad_derp = pytensor.grad(derp, some_vector)
+        fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
+            == expected_grad_n_alloc
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fgrad(test_params)
 
     def test_alloc_output(self):
         val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)
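
For context, the hunk header names a `check_allocs_in_fgraph(fgraph, n)` helper whose body is not shown in this diff. A minimal sketch of such a helper, and of the count-the-Alloc-nodes assertion the test performs inline, might look like the following (the imports and the usage graph are assumptions for illustration, not the test's actual code):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc


def check_allocs_in_fgraph(fgraph, n):
    # Count the Alloc apply nodes left after graph rewriting and
    # compare against the expected number.
    n_allocs = sum(isinstance(node.op, Alloc) for node in fgraph.apply_nodes)
    assert n_allocs == n, f"expected {n} Alloc nodes, found {n_allocs}"


# Hypothetical usage mirroring the assertions in the test above:
x = pt.vector("x")
y = pt.alloc(0.0, 50) + x  # default rewrites should remove this Alloc
f = pytensor.function([x], y)
check_allocs_in_fgraph(f.maker.fgraph, 0)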