@@ -716,6 +716,32 @@ def test_masked_array_not_implemented(
716716 ptb .as_tensor (x )
717717
718718
719+ def check_alloc_runtime_broadcast (mode ):
720+ """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
721+ floatX = config .floatX
722+ x_v = vector ("x" , shape = (None ,))
723+
724+ out = alloc (x_v , 5 , 3 )
725+ f = pytensor .function ([x_v ], out , mode = mode )
726+ TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
727+
728+ np .testing .assert_array_equal (
729+ f (x = np .zeros ((3 ,), dtype = floatX )),
730+ np .zeros ((5 , 3 ), dtype = floatX ),
731+ )
732+ with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
733+ f (x = np .zeros ((1 ,), dtype = floatX ))
734+
735+ out = alloc (specify_shape (x_v , (1 ,)), 5 , 3 )
736+ f = pytensor .function ([x_v ], out , mode = mode )
737+ TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
738+
739+ np .testing .assert_array_equal (
740+ f (x = np .zeros ((1 ,), dtype = floatX )),
741+ np .zeros ((5 , 3 ), dtype = floatX ),
742+ )
743+
744+
719745class TestAlloc :
720746 dtype = config .floatX
721747 mode = mode_opt
@@ -729,32 +755,6 @@ def check_allocs_in_fgraph(fgraph, n):
729755 == n
730756 )
731757
732- @staticmethod
733- def check_runtime_broadcast (mode ):
734- """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
735- floatX = config .floatX
736- x_v = vector ("x" , shape = (None ,))
737-
738- out = alloc (x_v , 5 , 3 )
739- f = pytensor .function ([x_v ], out , mode = mode )
740- TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
741-
742- np .testing .assert_array_equal (
743- f (x = np .zeros ((3 ,), dtype = floatX )),
744- np .zeros ((5 , 3 ), dtype = floatX ),
745- )
746- with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
747- f (x = np .zeros ((1 ,), dtype = floatX ))
748-
749- out = alloc (specify_shape (x_v , (1 ,)), 5 , 3 )
750- f = pytensor .function ([x_v ], out , mode = mode )
751- TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
752-
753- np .testing .assert_array_equal (
754- f (x = np .zeros ((1 ,), dtype = floatX )),
755- np .zeros ((5 , 3 ), dtype = floatX ),
756- )
757-
758758 def setup_method (self ):
759759 self .rng = np .random .default_rng (seed = utt .fetch_seed ())
760760
@@ -911,7 +911,7 @@ def test_alloc_of_view_linker(self):
911911
912912 @pytest .mark .parametrize ("mode" , (Mode ("py" ), Mode ("c" )))
913913 def test_runtime_broadcast (self , mode ):
914- self . check_runtime_broadcast (mode )
914+ check_alloc_runtime_broadcast (mode )
915915
916916
917917def test_infer_static_shape ():
0 commit comments