@@ -2537,7 +2537,7 @@ def perform(self, node, inputs, output_storage):
2537
2537
)
2538
2538
2539
2539
def c_code_cache_version (self ):
2540
- return (6 ,)
2540
+ return (7 ,)
2541
2541
2542
2542
def c_code (self , node , name , inputs , outputs , sub ):
2543
2543
axis , * arrays = inputs
@@ -2576,16 +2576,86 @@ def c_code(self, node, name, inputs, outputs, sub):
2576
2576
code = f"""
2577
2577
int axis = { axis_def }
2578
2578
PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2579
- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2579
+ int out_is_valid = { out } != NULL ;
2580
2580
2581
2581
{ axis_check }
2582
2582
2583
- Py_XDECREF({ out } );
2584
- { copy_arrays_to_tuple }
2585
- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2586
- Py_DECREF(arrays_tuple);
2587
- if(!{ out } ){{
2588
- { fail }
2583
+ if (out_is_valid) {{
2584
+ // Check if we can reuse output
2585
+ npy_intp join_size = 0;
2586
+ npy_intp out_shape[{ ndim } ];
2587
+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2588
+
2589
+ for (int i = 0; i < { n } ; i++) {{
2590
+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2591
+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2592
+ { fail }
2593
+ }}
2594
+
2595
+ join_size += PyArray_SHAPE(arrays[i])[axis];
2596
+
2597
+ if (i > 0){{
2598
+ for (int j = 0; j < { ndim } ; j++) {{
2599
+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2600
+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2601
+ { fail }
2602
+ }}
2603
+ }}
2604
+ }}
2605
+ }}
2606
+
2607
+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2608
+ out_shape[axis] = join_size;
2609
+
2610
+ for (int i = 0; i < { ndim } ; i++) {{
2611
+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2612
+ }}
2613
+ }}
2614
+
2615
+ if (!out_is_valid) {{
2616
+ // Use PyArray_Concatenate
2617
+ Py_XDECREF({ out } );
2618
+ PyObject* arrays_tuple = PyTuple_New({ n } );
2619
+ { copy_arrays_to_tuple }
2620
+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2621
+ Py_DECREF(arrays_tuple);
2622
+ if(!{ out } ){{
2623
+ { fail }
2624
+ }}
2625
+ }}
2626
+ else {{
2627
+ // Copy the data to the pre-allocated output buffer
2628
+
2629
+ // Create view into output buffer
2630
+ PyArrayObject_fields *view;
2631
+
2632
+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2633
+ Py_INCREF(PyArray_DESCR({ out } ));
2634
+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2635
+ PyArray_DESCR({ out } ),
2636
+ { ndim } ,
2637
+ PyArray_SHAPE(arrays[0]),
2638
+ PyArray_STRIDES({ out } ),
2639
+ PyArray_DATA({ out } ),
2640
+ NPY_ARRAY_WRITEABLE,
2641
+ NULL);
2642
+ if (view == NULL) {{
2643
+ { fail }
2644
+ }}
2645
+
2646
+ // Copy data into output buffer
2647
+ for (int i = 0; i < { n } ; i++) {{
2648
+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2649
+
2650
+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2651
+ Py_DECREF(view);
2652
+ { fail }
2653
+ }}
2654
+
2655
+ view->data += (view->dimensions[axis] * view->strides[axis]);
2656
+ }}
2657
+
2658
+ Py_DECREF(view);
2589
2659
}}
2590
2660
"""
2591
2661
return code
0 commit comments