chamfer_distance.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
import torch
from torch.utils.cpp_extension import load

# JIT-compile the chamfer-distance C++/CUDA extension at import time.
# `load` builds the listed sources (caching the result) and returns the
# compiled module; it exposes forward / forward_cuda / backward /
# backward_cuda used by ChamferDistanceFunction below.
cd = load(name="cd",
          sources=["chamfer_distance/chamfer_distance.cpp",
                   "chamfer_distance/chamfer_distance.cu"])
  6. class ChamferDistanceFunction(torch.autograd.Function):
  7. @staticmethod
  8. def forward(ctx, xyz1, xyz2):
  9. batchsize, n, _ = xyz1.size()
  10. _, m, _ = xyz2.size()
  11. xyz1 = xyz1.contiguous()
  12. xyz2 = xyz2.contiguous()
  13. dist1 = torch.zeros(batchsize, n)
  14. dist2 = torch.zeros(batchsize, m)
  15. idx1 = torch.zeros(batchsize, n, dtype=torch.int)
  16. idx2 = torch.zeros(batchsize, m, dtype=torch.int)
  17. if not xyz1.is_cuda:
  18. cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
  19. else:
  20. dist1 = dist1.cuda()
  21. dist2 = dist2.cuda()
  22. idx1 = idx1.cuda()
  23. idx2 = idx2.cuda()
  24. cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
  25. ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
  26. return dist1, dist2
  27. @staticmethod
  28. def backward(ctx, graddist1, graddist2):
  29. xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
  30. graddist1 = graddist1.contiguous()
  31. graddist2 = graddist2.contiguous()
  32. gradxyz1 = torch.zeros(xyz1.size())
  33. gradxyz2 = torch.zeros(xyz2.size())
  34. if not graddist1.is_cuda:
  35. cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
  36. else:
  37. gradxyz1 = gradxyz1.cuda()
  38. gradxyz2 = gradxyz2.cuda()
  39. cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
  40. return gradxyz1, gradxyz2
  41. class ChamferDistance(torch.nn.Module):
  42. def forward(self, xyz1, xyz2):
  43. return ChamferDistanceFunction.apply(xyz1, xyz2)