// chamfer_distance.cu
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdio>
  4. __global__
  5. void ChamferDistanceKernel(
  6. int b,
  7. int n,
  8. const float* xyz,
  9. int m,
  10. const float* xyz2,
  11. float* result,
  12. int* result_i)
  13. {
  14. const int batch=512;
  15. __shared__ float buf[batch*3];
  16. for (int i=blockIdx.x;i<b;i+=gridDim.x){
  17. for (int k2=0;k2<m;k2+=batch){
  18. int end_k=min(m,k2+batch)-k2;
  19. for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
  20. buf[j]=xyz2[(i*m+k2)*3+j];
  21. }
  22. __syncthreads();
  23. for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
  24. float x1=xyz[(i*n+j)*3+0];
  25. float y1=xyz[(i*n+j)*3+1];
  26. float z1=xyz[(i*n+j)*3+2];
  27. int best_i=0;
  28. float best=0;
  29. int end_ka=end_k-(end_k&3);
  30. if (end_ka==batch){
  31. for (int k=0;k<batch;k+=4){
  32. {
  33. float x2=buf[k*3+0]-x1;
  34. float y2=buf[k*3+1]-y1;
  35. float z2=buf[k*3+2]-z1;
  36. float d=x2*x2+y2*y2+z2*z2;
  37. if (k==0 || d<best){
  38. best=d;
  39. best_i=k+k2;
  40. }
  41. }
  42. {
  43. float x2=buf[k*3+3]-x1;
  44. float y2=buf[k*3+4]-y1;
  45. float z2=buf[k*3+5]-z1;
  46. float d=x2*x2+y2*y2+z2*z2;
  47. if (d<best){
  48. best=d;
  49. best_i=k+k2+1;
  50. }
  51. }
  52. {
  53. float x2=buf[k*3+6]-x1;
  54. float y2=buf[k*3+7]-y1;
  55. float z2=buf[k*3+8]-z1;
  56. float d=x2*x2+y2*y2+z2*z2;
  57. if (d<best){
  58. best=d;
  59. best_i=k+k2+2;
  60. }
  61. }
  62. {
  63. float x2=buf[k*3+9]-x1;
  64. float y2=buf[k*3+10]-y1;
  65. float z2=buf[k*3+11]-z1;
  66. float d=x2*x2+y2*y2+z2*z2;
  67. if (d<best){
  68. best=d;
  69. best_i=k+k2+3;
  70. }
  71. }
  72. }
  73. }else{
  74. for (int k=0;k<end_ka;k+=4){
  75. {
  76. float x2=buf[k*3+0]-x1;
  77. float y2=buf[k*3+1]-y1;
  78. float z2=buf[k*3+2]-z1;
  79. float d=x2*x2+y2*y2+z2*z2;
  80. if (k==0 || d<best){
  81. best=d;
  82. best_i=k+k2;
  83. }
  84. }
  85. {
  86. float x2=buf[k*3+3]-x1;
  87. float y2=buf[k*3+4]-y1;
  88. float z2=buf[k*3+5]-z1;
  89. float d=x2*x2+y2*y2+z2*z2;
  90. if (d<best){
  91. best=d;
  92. best_i=k+k2+1;
  93. }
  94. }
  95. {
  96. float x2=buf[k*3+6]-x1;
  97. float y2=buf[k*3+7]-y1;
  98. float z2=buf[k*3+8]-z1;
  99. float d=x2*x2+y2*y2+z2*z2;
  100. if (d<best){
  101. best=d;
  102. best_i=k+k2+2;
  103. }
  104. }
  105. {
  106. float x2=buf[k*3+9]-x1;
  107. float y2=buf[k*3+10]-y1;
  108. float z2=buf[k*3+11]-z1;
  109. float d=x2*x2+y2*y2+z2*z2;
  110. if (d<best){
  111. best=d;
  112. best_i=k+k2+3;
  113. }
  114. }
  115. }
  116. }
  117. for (int k=end_ka;k<end_k;k++){
  118. float x2=buf[k*3+0]-x1;
  119. float y2=buf[k*3+1]-y1;
  120. float z2=buf[k*3+2]-z1;
  121. float d=x2*x2+y2*y2+z2*z2;
  122. if (k==0 || d<best){
  123. best=d;
  124. best_i=k+k2;
  125. }
  126. }
  127. if (k2==0 || result[(i*n+j)]>best){
  128. result[(i*n+j)]=best;
  129. result_i[(i*n+j)]=best_i;
  130. }
  131. }
  132. __syncthreads();
  133. }
  134. }
  135. }
  136. void ChamferDistanceKernelLauncher(
  137. const int b, const int n,
  138. const float* xyz,
  139. const int m,
  140. const float* xyz2,
  141. float* result,
  142. int* result_i,
  143. float* result2,
  144. int* result2_i)
  145. {
  146. ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, n, xyz, m, xyz2, result, result_i);
  147. ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, m, xyz2, n, xyz, result2, result2_i);
  148. cudaError_t err = cudaGetLastError();
  149. if (err != cudaSuccess)
  150. printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
  151. }
  152. __global__
  153. void ChamferDistanceGradKernel(
  154. int b, int n,
  155. const float* xyz1,
  156. int m,
  157. const float* xyz2,
  158. const float* grad_dist1,
  159. const int* idx1,
  160. float* grad_xyz1,
  161. float* grad_xyz2)
  162. {
  163. for (int i = blockIdx.x; i<b; i += gridDim.x) {
  164. for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x*gridDim.y) {
  165. float x1=xyz1[(i*n+j)*3+0];
  166. float y1=xyz1[(i*n+j)*3+1];
  167. float z1=xyz1[(i*n+j)*3+2];
  168. int j2=idx1[i*n+j];
  169. float x2=xyz2[(i*m+j2)*3+0];
  170. float y2=xyz2[(i*m+j2)*3+1];
  171. float z2=xyz2[(i*m+j2)*3+2];
  172. float g=grad_dist1[i*n+j]*2;
  173. atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
  174. atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
  175. atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
  176. atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
  177. atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
  178. atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
  179. }
  180. }
  181. }
  182. void ChamferDistanceGradKernelLauncher(
  183. const int b, const int n,
  184. const float* xyz1,
  185. const int m,
  186. const float* xyz2,
  187. const float* grad_dist1,
  188. const int* idx1,
  189. const float* grad_dist2,
  190. const int* idx2,
  191. float* grad_xyz1,
  192. float* grad_xyz2)
  193. {
  194. cudaMemset(grad_xyz1, 0, b*n*3*4);
  195. cudaMemset(grad_xyz2, 0, b*m*3*4);
  196. ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
  197. ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
  198. cudaError_t err = cudaGetLastError();
  199. if (err != cudaSuccess)
  200. printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
  201. }