network.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import jittor as jt
  2. import numpy as np
  3. from jittor import nn
  4. from model.backbone.fpn import FPN50
  5. from model.dht import DHT_Layer
  6. class Net(nn.Module):
  7. def __init__(self, numAngle, numRho, backbone):
  8. super(Net, self).__init__()
  9. if backbone == 'resnet50':
  10. print('using resnet50 backbone.')
  11. self.backbone = FPN50(pretrained=True, output_stride=16)
  12. output_stride = 16
  13. else:
  14. raise NotImplementedError
  15. self.dht_detector1 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho)
  16. self.dht_detector2 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // 2)
  17. self.dht_detector3 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // 4)
  18. self.dht_detector4 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // (output_stride // 4))
  19. self.last_conv = nn.Sequential(
  20. nn.Conv2d(512, 1, 1)
  21. )
  22. self.numAngle = numAngle
  23. self.numRho = numRho
  24. def upsample_cat(self, p1, p2, p3, p4):
  25. p1 = nn.interpolate(p1, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
  26. p2 = nn.interpolate(p2, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
  27. p3 = nn.interpolate(p3, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
  28. p4 = nn.interpolate(p4, size=(self.numAngle, self.numRho), mode='bilinear', align_corners=True)
  29. return jt.concat([p1, p2, p3, p4], dim=1)
  30. def execute(self, x):
  31. p1, p2, p3, p4 = self.backbone(x)
  32. p4 = self.dht_detector4(p4)
  33. p3 = self.dht_detector3(p3)
  34. p2 = self.dht_detector2(p2)
  35. p1 = self.dht_detector1(p1)
  36. cat = self.upsample_cat(p1, p2, p3, p4)
  37. logist = self.last_conv(cat)
  38. return logist