diff --git a/include/ftk/basic/kd_lite.hh b/include/ftk/basic/kd_lite.hh index e78ce037..0c5600cf 100644 --- a/include/ftk/basic/kd_lite.hh +++ b/include/ftk/basic/kd_lite.hh @@ -8,9 +8,11 @@ namespace ftk { +// implementation of https://www.nvidia.com/content/gtc-2010/pdfs/2140_gtc2010.pdf template __host__ void kd_build_recursive( + const I n, const I current, const F *X, // coordinates const I level, // the current level @@ -20,26 +22,35 @@ void kd_build_recursive( I *ids) // out: pre-allocated array for ids { const I axis = level % nd; - const I half = length / 2; + const I h = std::ceil(std::log2(length+1)); + const I half = std::pow(2, h-2); + const I lastRow = length - 2 * half + 1; + const I lbm = half + std::min(half, lastRow); + // const I half = length / 2; // fprintf(stderr, "current=%d, offset=%d, length=%d\n", // current, offset, length); - std::nth_element( - ids + offset, - ids + offset + half, // median - ids + offset + length, - [X, axis](I i, I j) { - return X[i*nd+axis] < X[j*nd+axis]; - }); - - heap[current] = ids[offset + half]; // the median - // fprintf(stderr, "current=%d, offset=%d, length=%d, median=%d\n", current, offset, length, heap[current]); - - if (half >= 1) - kd_build_recursive(current*2+1, X, level+1, offset, half-1, heap, ids); // left - if (length - half >= 1) - kd_build_recursive(current*2+2, X, level+1, offset+half+1, length-half-1, heap, ids); // right + if (length == 1) { + heap[current] = ids[offset]; + // fprintf(stderr, "current=%d, offset=%d, length=%d, lbm=%d, median=%d\n", current, offset, length, lbm, heap[current]); + } else { + std::nth_element( + ids + offset, + ids + offset + lbm-1, // median + ids + offset + length, + [X, axis](I i, I j) { + return X[i*nd+axis] < X[j*nd+axis]; + }); + + heap[current] = ids[offset + lbm-1]; // the median + // fprintf(stderr, "current=%d, offset=%d, length=%d, lbm=%d, median=%d\n", current, offset, length, lbm, heap[current]); + + if (lbm - 1 >= 1) + kd_build_recursive(n, current*2+1, X, level+1, offset, lbm-1, heap, ids); // left + if (length - lbm >= 1) + kd_build_recursive(n, current*2+2, X, level+1, offset+lbm, length-lbm, heap, ids); // right + } } template @@ -54,7 +65,10 @@ void kd_build( for (int i = 0; i < n; i ++) ids[i] = i; - kd_build_recursive(0, X, 0, 0, n, heap, ids.data()); + kd_build_recursive(n, 0, X, 0, 0, n, heap, ids.data()); + + // for (int i = 0; i < n; i ++) + // fprintf(stderr, "i=%d, heap=%d\n", i, heap[i]); } template