Skip to content

Commit 4d3f133

Browse files
committed
Reuse output buffer in C-impl of Join
1 parent 7205f36 commit 4d3f133

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

pytensor/tensor/basic.py

+78-8
Original file line numberDiff line numberDiff line change
@@ -2537,7 +2537,7 @@ def perform(self, node, inputs, output_storage):
25372537
)
25382538

25392539
def c_code_cache_version(self):
2540-
return (6,)
2540+
return (7,)
25412541

25422542
def c_code(self, node, name, inputs, outputs, sub):
25432543
axis, *arrays = inputs
@@ -2576,16 +2576,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25762576
code = f"""
25772577
int axis = {axis_def}
25782578
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
2579-
PyObject* arrays_tuple = PyTuple_New({n});
2579+
int out_is_valid = {out} != NULL;
25802580
25812581
{axis_check}
25822582
2583-
Py_XDECREF({out});
2584-
{copy_arrays_to_tuple}
2585-
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2586-
Py_DECREF(arrays_tuple);
2587-
if(!{out}){{
2588-
{fail}
2583+
if (out_is_valid) {{
2584+
// Check if we can reuse output
2585+
npy_intp join_size = 0;
2586+
npy_intp out_shape[{ndim}];
2587+
npy_intp *shape = PyArray_SHAPE(arrays[0]);
2588+
2589+
for (int i = 0; i < {n}; i++) {{
2590+
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
2591+
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2592+
{fail}
2593+
}}
2594+
2595+
join_size += PyArray_SHAPE(arrays[i])[axis];
2596+
2597+
if (i > 0){{
2598+
for (int j = 0; j < {ndim}; j++) {{
2599+
if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2600+
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2601+
{fail}
2602+
}}
2603+
}}
2604+
}}
2605+
}}
2606+
2607+
memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
2608+
out_shape[axis] = join_size;
2609+
2610+
for (int i = 0; i < {ndim}; i++) {{
2611+
out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
2612+
}}
2613+
}}
2614+
2615+
if (!out_is_valid) {{
2616+
// Use PyArray_Concatenate
2617+
Py_XDECREF({out});
2618+
PyObject* arrays_tuple = PyTuple_New({n});
2619+
{copy_arrays_to_tuple}
2620+
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2621+
Py_DECREF(arrays_tuple);
2622+
if(!{out}){{
2623+
{fail}
2624+
}}
2625+
}}
2626+
else {{
2627+
// Copy the data to the pre-allocated output buffer
2628+
2629+
// Create view into output buffer
2630+
PyArrayObject_fields *view;
2631+
2632+
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2633+
Py_INCREF(PyArray_DESCR({out}));
2634+
view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2635+
PyArray_DESCR({out}),
2636+
{ndim},
2637+
PyArray_SHAPE(arrays[0]),
2638+
PyArray_STRIDES({out}),
2639+
PyArray_DATA({out}),
2640+
NPY_ARRAY_WRITEABLE,
2641+
NULL);
2642+
if (view == NULL) {{
2643+
{fail}
2644+
}}
2645+
2646+
// Copy data into output buffer
2647+
for (int i = 0; i < {n}; i++) {{
2648+
view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2649+
2650+
if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2651+
Py_DECREF(view);
2652+
{fail}
2653+
}}
2654+
2655+
view->data += (view->dimensions[axis] * view->strides[axis]);
2656+
}}
2657+
2658+
Py_DECREF(view);
25892659
}}
25902660
"""
25912661
return code

tests/tensor/test_basic.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
ivector,
118118
lscalar,
119119
lvector,
120+
matrices,
120121
matrix,
121122
row,
122123
scalar,
@@ -1762,7 +1763,7 @@ def test_join_matrixV_negative_axis(self):
17621763
got = f(-2)
17631764
assert np.allclose(got, want)
17641765

1765-
with pytest.raises(IndexError):
1766+
with pytest.raises(ValueError):
17661767
f(-3)
17671768

17681769
@pytest.mark.parametrize("py_impl", (False, True))
@@ -1805,7 +1806,7 @@ def test_join_matrixC_negative_axis(self, py_impl):
18051806
got = f()
18061807
assert np.allclose(got, want)
18071808

1808-
with pytest.raises(IndexError):
1809+
with pytest.raises(ValueError):
18091810
join(-3, a, b)
18101811

18111812
with impl_ctxt:
@@ -2156,6 +2157,32 @@ def test_split_view(self, linker):
21562157
# C impl always makes a copy
21572158
assert r.base is not x_test
21582159

2160+
@pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
@pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
@pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
@config.change_flags(cmodule__warn_no_version=False)
def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
    """Benchmark ``join`` of six equally shaped inputs along ``axis``.

    Parameters
    ----------
    ndim
        1 joins vectors of length ``n``; 2 joins ``(n, n)`` matrices.
    axis
        Axis along which the six inputs are joined.
    memory_layout
        ``"C-contiguous"`` passes the arrays as created, ``"F-contiguous"``
        passes their transposes, and ``"Mixed"`` alternates between the two.
    gc
        Value assigned to ``fn.vm.allow_gc`` before the function is called.
        NOTE(review): presumably ``gc=False`` keeps the previous output
        around so the C implementation can reuse it -- confirm.
    benchmark
        Callable fixture that is invoked with the compiled function and
        its arguments.
    """
    # For vectors only axis 0 exists and transposing is a no-op, so every
    # other parameter combination would repeat the same measurement.
    if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
        pytest.skip("Redundant parametrization")
    n = 64
    inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
    out = join(axis, *inputs)
    fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
    fn.vm.allow_gc = gc
    test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
    if memory_layout == "C-contiguous":
        pass
    elif memory_layout == "F-contiguous":
        test_values = [t.T for t in test_values]
    elif memory_layout == "Mixed":
        test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
    else:
        raise ValueError(f"Unknown memory_layout: {memory_layout!r}")

    # Compute the expected shape first. Written inline as
    # ``assert a == b if cond else c`` the conditional expression binds
    # looser than ``==``, so for axis == 1 the assert would only test the
    # truthiness of a non-empty tuple and could never fail.
    expected_shape = (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
    assert fn(*test_values).shape == expected_shape
    benchmark(fn, *test_values)
2185+
21592186

21602187
def test_TensorFromScalar():
21612188
s = ps.constant(56)

0 commit comments

Comments
 (0)