Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions contrib/HSDF-Net/ChamferDistancePytorch/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*__pycache__*
/tmp
21 changes: 21 additions & 0 deletions contrib/HSDF-Net/ChamferDistancePytorch/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 ThibaultGROUEIX

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
103 changes: 103 additions & 0 deletions contrib/HSDF-Net/ChamferDistancePytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
`pip install torch ninja`

# Pytorch Chamfer Distance.

Include a **CUDA** version, and a **PYTHON** version with pytorch standard operations.
NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt thresholds accordingly.

- [x] F - Score



### CUDA VERSION

- [x] JIT compilation
- [x] Supports multi-gpu
- [x] 2D point clouds.
- [x] 3D point clouds.
- [x] 5D point clouds.
- [x] Contiguous() safe.



### Python Version

- [x] Supports any dimension



### Usage

```python
import torch, chamfer3D.dist_chamfer_3D, fscore
chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
points1 = torch.rand(32, 1000, 3).cuda()
points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda()
dist1, dist2, idx1, idx2 = chamLoss(points1, points2)
f_score, precision, recall = fscore.fscore(dist1, dist2)
```



### Add it to your project as a submodule

```shell
git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
```



### Benchmark: [forward + backward] pass
- [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4
- [x] p1 : 32 x 2000 x dim
- [x] p2 : 32 x 1000 x dim

| *Timing (sec * 1000)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | **1.2** | 1.4 |1.8 |
| **Cuda JIT** | 1.3 | **1.4** |**1.5** |
| **Python** | 37 | 37 | 37 |


| *Memory (MB)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | 529 | 529 | 549 |
| **Cuda JIT** | **520** | **529** |**549** |
| **Python** | 2495 | 2495 | 2495 |



### What is the chamfer distance ?

[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning



### Aknowledgment

Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu).

JIT cool trick from [Christian Diller](https://github.com/chrdiller)

### Troubleshoot

- `Undefined symbol: Zxxxxxxxxxxxxxxxxx `:

--> Fix: Make sure to `import torch` before you `import chamfer`.
--> Use pytorch.version >= 1.1.0

- [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167)

```shell
wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
```





#### TODO:

* Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions
182 changes: 182 additions & 0 deletions contrib/HSDF-Net/ChamferDistancePytorch/chamfer2D/chamfer2D.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@

#include <stdio.h>
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>



__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=512;
__shared__ float buf[batch*2];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*2;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*2+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*2+0];
float y1=xyz[(i*n+j)*2+1];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&2);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*2+2]-x1;
float y2=buf[k*2+3]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*2+4]-x1;
float y2=buf[k*2+5]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*2+6]-x1;
float y2=buf[k*2+7]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*2+2]-x1;
float y2=buf[k*2+3]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*2+4]-x1;
float y2=buf[k*2+5]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*2+6]-x1;
float y2=buf[k*2+7]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){

const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B

NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;


}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*2+0];
float y1=xyz1[(i*n+j)*2+1];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*2+0];
float y2=xyz2[(i*m+j2)*2+1];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*2+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*2+1]),g*(y1-y2));
atomicAdd(&(grad_xyz2[(i*m+j2)*2+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*2+1]),-(g*(y1-y2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);

const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B

NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;

}

33 changes: 33 additions & 0 deletions contrib/HSDF-Net/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>

///TMP
//#include "common.h"
/// NOT TMP


int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);


int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);




int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}


int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {

return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}



PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}
Loading