dht.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import jittor as jt
  2. from jittor import Function, nn
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from model.cuda_src import cuda_src_forward as csf
  6. from model.cuda_src import cuda_src_backward as csb
  7. class DHT_Func(Function):
  8. def execute(self, x, numangle, numrho):
  9. n, c, h, w = x.shape
  10. cuda_src_forward = csf.replace('#numangle', str(numangle))
  11. cuda_src_forward = cuda_src_forward.replace('#numrho', str(numrho))
  12. irho = int((h*h + w*w)**0.5 + 1) / float((numrho - 1))
  13. itheta = 3.14159265358979323846 / numangle
  14. angle = jt.arange(numangle) * itheta
  15. tabCos = angle.cos() / irho
  16. tabSin = angle.sin() / irho
  17. output = jt.code([n, c, numangle, numrho], x.dtype, [x, tabCos, tabSin],
  18. cuda_src=cuda_src_forward)
  19. self.save_vars = x, numangle, numrho
  20. return output
  21. def grad(self, grad):
  22. x, numangle, numrho = self.save_vars
  23. cuda_src_backward = csb.replace('#numangle', str(numangle))
  24. cuda_src_backward = cuda_src_backward.replace('#numrho', str(numrho))
  25. irho = int((h*h + w*w)**0.5 + 1) / float((numrho - 1))
  26. itheta = 3.14159265358979323846 / numangle
  27. angle = jt.arange(numangle) * itheta
  28. tabCos = angle.cos() / irho
  29. tabSin = angle.sin() / irho
  30. return jt.code([x.shape], [x.dtype], [x, grad, tabCos, tabSin],
  31. cuda_src=cuda_src_backward)
  32. class C_dht(nn.Module):
  33. def __init__(self, numAngle, numRho):
  34. super(C_dht, self).__init__()
  35. self.numAngle = numAngle
  36. self.numRho = numRho
  37. def execute(self, feat):
  38. return DHT_Func.apply(feat, self.numAngle, self.numRho)
  39. class DHT(nn.Module):
  40. def __init__(self, numAngle, numRho):
  41. super(DHT, self).__init__()
  42. self.line_agg = C_dht(numAngle, numRho)
  43. def execute(self, x):
  44. accum = self.line_agg(x)
  45. return accum
  46. class DHT_Layer(nn.Module):
  47. def __init__(self, input_dim, dim, numAngle, numRho):
  48. super(DHT_Layer, self).__init__()
  49. self.fist_conv = nn.Sequential(
  50. nn.Conv2d(input_dim, dim, 1),
  51. nn.BatchNorm2d(dim),
  52. nn.ReLU()
  53. )
  54. self.dht = DHT(numAngle=numAngle, numRho=numRho)
  55. self.convs = nn.Sequential(
  56. nn.Conv2d(dim, dim, 3, 1, 1),
  57. nn.BatchNorm2d(dim),
  58. nn.ReLU(),
  59. nn.Conv2d(dim, dim, 3, 1, 1),
  60. nn.BatchNorm2d(dim),
  61. nn.ReLU()
  62. )
  63. def execute(self, x):
  64. x = self.fist_conv(x)
  65. x = self.dht(x)
  66. x = self.convs(x)
  67. return x