@finnBsch finnBsch commented Sep 11, 2025

Add KD-Tree Alternative for KNN Alignment

Overview

This PR introduces an alternative implementation for K-Nearest Neighbors (KNN) alignment in the depth completion pipeline, based on KD-trees via torch_kdtree, with the goal of reducing inference time. Users can now choose between the existing torch_cluster implementation and the new KD-tree approach.

Changes

Core Implementation

  • Added _knn_aligns_kdtree() method as an alternative to _knn_aligns_torch_cluster()
  • Modified knn_aligns() to dispatch between implementations based on the kd_tree parameter
  • Updated forward() and kss_completer() methods to accept the kd_tree boolean parameter

This change is fully backward compatible: existing code continues to work unchanged when the kd_tree parameter is set to False.
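The dispatch pattern described above can be sketched as follows. This is an illustrative stand-in, not the PR's actual code: the real methods wrap torch_cluster and torch_kdtree on GPU tensors, whereas here SciPy's cKDTree and a NumPy brute-force search play their roles.

```python
import numpy as np
from scipy.spatial import cKDTree


def _knn_brute_force(points, queries, k):
    # Stand-in for the torch_cluster path: full pairwise distance matrix.
    d = np.linalg.norm(queries[:, None, :] - points[None, :, :], axis=-1)
    idx = np.argsort(d, axis=1)[:, :k]
    return np.take_along_axis(d, idx, axis=1), idx


def _knn_kdtree(points, queries, k):
    # Stand-in for the torch_kdtree path, using SciPy's KD-tree.
    dists, idx = cKDTree(points).query(queries, k=k)
    return dists, idx


def knn_aligns(points, queries, k, kd_tree=True):
    # Dispatch between the two backends on the kd_tree flag,
    # mirroring how the PR's knn_aligns() selects an implementation.
    return (_knn_kdtree if kd_tree else _knn_brute_force)(points, queries, k)
```

Both paths return the same (distances, indices) shape, so callers are unaffected by the flag apart from runtime.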

API Changes

# New parameter in forward() method
def forward(self, ..., kd_tree: bool = True):

# New parameter in kss_completer() method  
def kss_completer(self, ..., kd_tree: bool = True):

Performance Characteristics

Performance varies based on image resolution and sparsity levels. The KD-tree implementation can provide substantial speedups for high-resolution images:

Example Performance (1920x1280 resolution):

  • 60% sparsity: KD-tree is 3.91x faster (0.84s vs 3.27s)
  • 90% sparsity: KD-tree is 4.89x faster (0.24s vs 1.20s)

However, for smaller images, torch_cluster may be faster.
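Since the crossover point depends on resolution and sparsity, a quick wall-clock comparison of the two code paths on your own data is the most reliable guide. A generic timing helper might look like this (a sketch using time.perf_counter; plug in the actual implementations):

```python
import time
import numpy as np


def time_knn(fn, points, queries, k, runs=5):
    # Best-of-N wall-clock timing. Run once first as a warm-up,
    # which matters for GPU or JIT-compiled backends.
    fn(points, queries, k)
    best = float("inf")
    for _ in range(runs):
        t0 = time.perf_counter()
        fn(points, queries, k)
        best = min(best, time.perf_counter() - t0)
    return best
```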

Validation and Testing

Benchmark Script

The benchmark script benchmark_knn_direct.py is included to measure execution time and validate correctness.

Usage:

python benchmark_knn_direct.py \
    --sparsity 0.60 0.90 \
    --runs 5 \
    --K 5 \
    --height 1280 \
    --width 1920

Parameters:

  • --sparsity: List of sparsity ratios to test (default: [0.60, 0.90])
  • --runs: Number of benchmark runs per test (default: 5)
  • --K: Number of nearest neighbors (default: 5)
  • --height: Image height for testing (default: 1280)
  • --width: Image width for testing (default: 1920)
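A minimal sketch of how such a CLI can be wired up with argparse. The flag names and defaults follow the list above; the actual benchmark_knn_direct.py may be structured differently.

```python
import argparse


def build_parser():
    # Flag names and defaults mirror the benchmark's documented interface.
    p = argparse.ArgumentParser(description="Benchmark KNN implementations")
    p.add_argument("--sparsity", type=float, nargs="+", default=[0.60, 0.90],
                   help="List of sparsity ratios to test")
    p.add_argument("--runs", type=int, default=5,
                   help="Number of benchmark runs per test")
    p.add_argument("--K", type=int, default=5,
                   help="Number of nearest neighbors")
    p.add_argument("--height", type=int, default=1280)
    p.add_argument("--width", type=int, default=1920)
    return p


# Parse an explicit argv list, as the Usage example above would produce.
args = build_parser().parse_args(["--sparsity", "0.60", "0.90", "--runs", "5"])
```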

Correctness Validation

The benchmark includes distance validation to ensure both implementations produce mathematically equivalent results. Any index differences are due to legitimate tie-breaking when multiple points are equidistant, which can lead to different neighbors being returned. Hence, the correctness tests check three conditions:

  • reported KNN distances are equivalent (this should hold even under different tie-breaking behavior)
  • reported neighbor indices are valid
  • reported neighbor indices reproduce the reported distances (checked against a manual computation)

We deliberately do not check that the KNN indices themselves match, since there are multiple "correct" solutions when neighbors are equidistant.
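The three checks can be sketched like this (a simplified NumPy version; the real script operates on the pipeline's tensors and the two backends' actual outputs):

```python
import numpy as np


def validate_knn(points, queries, d_a, i_a, d_b, i_b, atol=1e-6):
    # 1. Reported distances must match, even if tie-breaking differs.
    assert np.allclose(np.sort(d_a, axis=1), np.sort(d_b, axis=1), atol=atol)
    # 2. Reported indices must be valid row indices into `points`.
    for idx in (i_a, i_b):
        assert idx.min() >= 0 and idx.max() < len(points)
    # 3. Indices must reproduce the reported distances when recomputed.
    for d, idx in ((d_a, i_a), (d_b, i_b)):
        recomputed = np.linalg.norm(points[idx] - queries[:, None, :], axis=-1)
        assert np.allclose(np.sort(recomputed, axis=1), np.sort(d, axis=1),
                           atol=atol)
    # Note: we deliberately do NOT require i_a == i_b, since equidistant
    # neighbors admit multiple correct index sets.
    return True
```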

Dependencies

The KD-tree implementation requires:

pip install torch_kdtree

Alternatively, install it from source; this is what I did and tested with, and is why I have not added it to the requirements file.

Use the provided benchmark script to make an informed decision for your use case.

Addresses #4, #22, and maybe #23.

cc @Ammoniro @HeoeH


huhuman commented Oct 22, 2025

Thanks for the work! I used it and was able to obtain comparable results for hundreds of images in my reconstruction case.
