# dht_func.py

import torch
import deep_hough as dh
import numpy as np
import matplotlib.pyplot as plt
import time


class C_dht_Function(torch.autograd.Function):
    """Autograd wrapper around the deep_hough CUDA forward/backward kernels."""

    @staticmethod
    def forward(ctx, feat, numangle, numrho):
        N, C, _, _ = feat.size()
        # Hough-space accumulator: one (numangle, numrho) plane per channel.
        out = torch.zeros(N, C, numangle, numrho).type_as(feat).cuda()
        out = dh.forward(feat, out, numangle, numrho)
        outputs = out[0]
        # Save the input feature map and the Hough resolution for backward.
        ctx.save_for_backward(feat)
        ctx.numangle = numangle
        ctx.numrho = numrho
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        feat = ctx.saved_tensors[0]
        numangle = ctx.numangle
        numrho = ctx.numrho
        # Gradient buffer with the same shape as the input feature map.
        out = torch.zeros_like(feat).type_as(feat).cuda()
        out = dh.backward(grad_output.contiguous(), out, feat, numangle, numrho)
        grad_in = out[0]
        # No gradients for the integer arguments numangle and numrho.
        return grad_in, None, None


class C_dht(torch.nn.Module):
    """Deep Hough transform layer: maps (N, C, H, W) features to
    (N, C, numAngle, numRho) Hough-space accumulators."""

    def __init__(self, numAngle, numRho):
        super(C_dht, self).__init__()
        self.numAngle = numAngle
        self.numRho = numRho

    def forward(self, feat):
        return C_dht_Function.apply(feat, self.numAngle, self.numRho)
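

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the C_dht layer might be exercised, assuming the
# deep_hough CUDA extension has been built and a CUDA device is available.
# The feature shape and the angle/rho resolutions below are arbitrary
# placeholders, not values prescribed by the module itself.
if __name__ == "__main__":
    dht_layer = C_dht(numAngle=100, numRho=100)
    feat = torch.randn(1, 16, 32, 32, requires_grad=True, device="cuda")
    hough = dht_layer(feat)      # shape: (1, 16, 100, 100)
    hough.sum().backward()       # gradients reach feat via dh.backward
    print(hough.shape, feat.grad.shape)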