Skip to content

Commit be43535

Browse files
committed
Specialize Zero Alloc
1 parent 9e47a30 commit be43535

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pytensor/tensor/basic.py

+10 -3
Original file line number | Diff line number | Diff line change
@@ -1648,6 +1648,11 @@ def c_code(self, node, name, inp, out, sub):
16481648
o_static_shape = node.outputs[0].type.shape
16491649
v_ndim = len(v_static_shape)
16501650
o_ndim = len(o_static_shape)
1651+
is_zero = (
1652+
all(node.inputs[0].type.broadcastable)
1653+
and isinstance(node.inputs[0], Constant)
1654+
and (node.inputs[0].unique_value == 0)
1655+
)
16511656
assert o_ndim == len(inp[1:])
16521657

16531658
# Declare variables
@@ -1688,16 +1693,18 @@ def c_code(self, node, name, inp, out, sub):
16881693
{fail}
16891694
}}
16901695
}}
1691-
1696+
if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
1697+
PyArray_FILLWBYTE({zz}, 0);
1698+
}}
16921699
// This function takes care of broadcasting
1693-
if (PyArray_CopyInto({zz}, {vv}) == -1)
1700+
else if (PyArray_CopyInto({zz}, {vv}) == -1)
16941701
{fail}
16951702
"""
16961703

16971704
return code
16981705

16991706
def c_code_cache_version(self):
1700-
return (4,)
1707+
return (5,)
17011708

17021709
def infer_shape(self, fgraph, node, input_shapes):
17031710
return [node.inputs[1:]]

0 commit comments

Comments
 (0)