chamfer_distance.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. #include <torch/torch.h>
  2. // CUDA forward declarations
  /// Forward-pass launcher; declared here, defined in another translation unit.
  ///
  /// In this file it is invoked by chamfer_distance_forward_cuda with
  ///   b         = xyz1.size(0)
  ///   n, xyz    = xyz1.size(1) and the float data of xyz1
  ///   m, xyz2   = xyz2.size(1) and the float data of xyz2
  ///   result,  result_i   = float data of dist1 / int data of idx1
  ///   result2, result2_i  = float data of dist2 / int data of idx2
  ///
  /// NOTE(review): the meaning of the int return value is not visible from this
  /// file (the caller here discards it) -- confirm against the definition.
  int ChamferDistanceKernelLauncher(
      const int b, const int n, const float* xyz,
      const int m, const float* xyz2,
      float* result, int* result_i,
      float* result2, int* result2_i);
  /// Backward-pass launcher; declared here, defined in another translation unit.
  ///
  /// In this file it is invoked by chamfer_distance_backward_cuda with
  ///   b                  = xyz1.size(0)
  ///   n, xyz1            = xyz1.size(1) and the float data of xyz1
  ///   m, xyz2            = xyz2.size(1) and the float data of xyz2
  ///   grad_dist1, idx1   = float data of graddist1 / int data of idx1
  ///   grad_dist2, idx2   = float data of graddist2 / int data of idx2
  ///   grad_xyz1, grad_xyz2 = float data of gradxyz1 / gradxyz2 (outputs)
  ///
  /// NOTE(review): the meaning of the int return value is not visible from this
  /// file (the caller here discards it) -- confirm against the definition.
  int ChamferDistanceGradKernelLauncher(
      const int b, const int n, const float* xyz1,
      const int m, const float* xyz2,
      const float* grad_dist1, const int* idx1,
      const float* grad_dist2, const int* idx2,
      float* grad_xyz1, float* grad_xyz2);
  23. void chamfer_distance_forward_cuda(
  24. const at::Tensor xyz1,
  25. const at::Tensor xyz2,
  26. const at::Tensor dist1,
  27. const at::Tensor dist2,
  28. const at::Tensor idx1,
  29. const at::Tensor idx2)
  30. {
  31. ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
  32. xyz2.size(1), xyz2.data<float>(),
  33. dist1.data<float>(), idx1.data<int>(),
  34. dist2.data<float>(), idx2.data<int>());
  35. }
  36. void chamfer_distance_backward_cuda(
  37. const at::Tensor xyz1,
  38. const at::Tensor xyz2,
  39. at::Tensor gradxyz1,
  40. at::Tensor gradxyz2,
  41. at::Tensor graddist1,
  42. at::Tensor graddist2,
  43. at::Tensor idx1,
  44. at::Tensor idx2)
  45. {
  46. ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
  47. xyz2.size(1), xyz2.data<float>(),
  48. graddist1.data<float>(), idx1.data<int>(),
  49. graddist2.data<float>(), idx2.data<int>(),
  50. gradxyz1.data<float>(), gradxyz2.data<float>());
  51. }
  /// Brute-force nearest-neighbour search on the CPU.
  ///
  /// For every batch `i` in [0, b) and every point `j` in [0, n) of `xyz1`,
  /// scans all `m` points of `xyz2` belonging to the same batch and records the
  /// smallest squared Euclidean distance and the index at which it occurs.
  ///
  /// @param b     number of batches
  /// @param n     points per batch in xyz1
  /// @param m     points per batch in xyz2
  /// @param xyz1  query coordinates, read as b*n consecutive (x, y, z) triples
  /// @param xyz2  candidate coordinates, read as b*m consecutive (x, y, z) triples
  /// @param dist  output, b*n squared distances
  /// @param idx   output, b*n indices into the matching batch of xyz2
  ///
  /// Ties keep the lowest index (a later candidate must be strictly closer).
  /// When m == 0 the outputs are written as distance 0 and index 0.
  ///
  /// Flat offsets are computed in 64-bit arithmetic: the previous
  /// `(i*n+j)*3` in `int` is signed overflow (undefined behaviour) once
  /// b*n*3 or b*m*3 exceeds INT_MAX.
  void nnsearch(
      const int b, const int n, const int m,
      const float* xyz1,
      const float* xyz2,
      float* dist,
      int* idx)
  {
      for (int batch = 0; batch < b; ++batch) {
          // Per-batch base offsets, widened before multiplying.
          const long long query_base  = static_cast<long long>(batch) * n;
          const long long target_base = static_cast<long long>(batch) * m;

          for (int q = 0; q < n; ++q) {
              const long long query_pos = query_base + q;
              const float* query = xyz1 + query_pos * 3;
              const float qx = query[0];
              const float qy = query[1];
              const float qz = query[2];

              double best = 0;
              int best_index = 0;
              for (int k = 0; k < m; ++k) {
                  const float* target = xyz2 + (target_base + k) * 3;
                  const float dx = target[0] - qx;
                  const float dy = target[1] - qy;
                  const float dz = target[2] - qz;
                  // Evaluated in float, exactly as before, then widened.
                  const double d = dx*dx + dy*dy + dz*dz;
                  // The first candidate always seeds the running minimum.
                  if (k == 0 || d < best) {
                      best = d;
                      best_index = k;
                  }
              }

              dist[query_pos] = static_cast<float>(best);
              idx[query_pos] = best_index;
          }
      }
  }
  81. void chamfer_distance_forward(
  82. const at::Tensor xyz1,
  83. const at::Tensor xyz2,
  84. const at::Tensor dist1,
  85. const at::Tensor dist2,
  86. const at::Tensor idx1,
  87. const at::Tensor idx2)
  88. {
  89. const int batchsize = xyz1.size(0);
  90. const int n = xyz1.size(1);
  91. const int m = xyz2.size(1);
  92. const float* xyz1_data = xyz1.data<float>();
  93. const float* xyz2_data = xyz2.data<float>();
  94. float* dist1_data = dist1.data<float>();
  95. float* dist2_data = dist2.data<float>();
  96. int* idx1_data = idx1.data<int>();
  97. int* idx2_data = idx2.data<int>();
  98. nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
  99. nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
  100. }
  101. void chamfer_distance_backward(
  102. const at::Tensor xyz1,
  103. const at::Tensor xyz2,
  104. at::Tensor gradxyz1,
  105. at::Tensor gradxyz2,
  106. at::Tensor graddist1,
  107. at::Tensor graddist2,
  108. at::Tensor idx1,
  109. at::Tensor idx2)
  110. {
  111. const int b = xyz1.size(0);
  112. const int n = xyz1.size(1);
  113. const int m = xyz2.size(1);
  114. const float* xyz1_data = xyz1.data<float>();
  115. const float* xyz2_data = xyz2.data<float>();
  116. float* gradxyz1_data = gradxyz1.data<float>();
  117. float* gradxyz2_data = gradxyz2.data<float>();
  118. float* graddist1_data = graddist1.data<float>();
  119. float* graddist2_data = graddist2.data<float>();
  120. const int* idx1_data = idx1.data<int>();
  121. const int* idx2_data = idx2.data<int>();
  122. for (int i = 0; i < b*n*3; i++)
  123. gradxyz1_data[i] = 0;
  124. for (int i = 0; i < b*m*3; i++)
  125. gradxyz2_data[i] = 0;
  126. for (int i = 0;i < b; i++) {
  127. for (int j = 0; j < n; j++) {
  128. const float x1 = xyz1_data[(i*n+j)*3+0];
  129. const float y1 = xyz1_data[(i*n+j)*3+1];
  130. const float z1 = xyz1_data[(i*n+j)*3+2];
  131. const int j2 = idx1_data[i*n+j];
  132. const float x2 = xyz2_data[(i*m+j2)*3+0];
  133. const float y2 = xyz2_data[(i*m+j2)*3+1];
  134. const float z2 = xyz2_data[(i*m+j2)*3+2];
  135. const float g = graddist1_data[i*n+j]*2;
  136. gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
  137. gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
  138. gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
  139. gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
  140. gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
  141. gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
  142. }
  143. for (int j = 0; j < m; j++) {
  144. const float x1 = xyz2_data[(i*m+j)*3+0];
  145. const float y1 = xyz2_data[(i*m+j)*3+1];
  146. const float z1 = xyz2_data[(i*m+j)*3+2];
  147. const int j2 = idx2_data[i*m+j];
  148. const float x2 = xyz1_data[(i*n+j2)*3+0];
  149. const float y2 = xyz1_data[(i*n+j2)*3+1];
  150. const float z2 = xyz1_data[(i*n+j2)*3+2];
  151. const float g = graddist2_data[i*m+j]*2;
  152. gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
  153. gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
  154. gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
  155. gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
  156. gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
  157. gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
  158. }
  159. }
  160. }
  161. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  162. m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
  163. m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
  164. m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
  165. m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
  166. }