network.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import torch
  2. import numpy as np
  3. import torch.nn as nn
  4. from model.backbone.fpn import FPN101, FPN50, FPN18, ResNext50_FPN
  5. from model.backbone.mobilenet import MobileNet_FPN
  6. from model.backbone.vgg_fpn import VGG_FPN
  7. from model.backbone.res2net import res2net50_FPN
  8. from model.dht import DHT_Layer
  9. class Net(nn.Module):
  10. def __init__(self, numAngle, numRho, backbone):
  11. super(Net, self).__init__()
  12. if backbone == 'resnet18':
  13. self.backbone = FPN18(pretrained=True, output_stride=32)
  14. output_stride = 32
  15. if backbone == 'resnet50':
  16. self.backbone = FPN50(pretrained=True, output_stride=16)
  17. output_stride = 16
  18. if backbone == 'resnet101':
  19. self.backbone = FPN101(output_stride=16)
  20. output_stride = 16
  21. if backbone == 'resnext50':
  22. self.backbone = ResNext50_FPN(output_stride=16)
  23. output_stride = 16
  24. if backbone == 'vgg16':
  25. self.backbone = VGG_FPN()
  26. output_stride = 16
  27. if backbone == 'mobilenetv2':
  28. self.backbone = MobileNet_FPN()
  29. output_stride = 32
  30. if backbone == 'res2net50':
  31. self.backbone = res2net50_FPN()
  32. output_stride = 32
  33. if backbone == 'mobilenetv2':
  34. self.dht_detector1 = DHT_Layer(32, 32, numAngle=numAngle, numRho=numRho)
  35. self.dht_detector2 = DHT_Layer(32, 32, numAngle=numAngle, numRho=numRho // 2)
  36. self.dht_detector3 = DHT_Layer(32, 32, numAngle=numAngle, numRho=numRho // 4)
  37. self.dht_detector4 = DHT_Layer(32, 32, numAngle=numAngle, numRho=numRho // (output_stride // 4))
  38. self.last_conv = nn.Sequential(
  39. nn.Conv2d(128, 1, 1)
  40. )
  41. else:
  42. self.dht_detector1 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho)
  43. self.dht_detector2 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // 2)
  44. self.dht_detector3 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // 4)
  45. self.dht_detector4 = DHT_Layer(256, 128, numAngle=numAngle, numRho=numRho // (output_stride // 4))
  46. self.last_conv = nn.Sequential(
  47. nn.Conv2d(512, 1, 1)
  48. )
  49. self.numAngle = numAngle
  50. self.numRho = numRho
  51. def upsample_cat(self, p1, p2, p3, p4):
  52. p1 = nn.functional.interpolate(p1, size=(self.numAngle, self.numRho), mode='bilinear')
  53. p2 = nn.functional.interpolate(p2, size=(self.numAngle, self.numRho), mode='bilinear')
  54. p3 = nn.functional.interpolate(p3, size=(self.numAngle, self.numRho), mode='bilinear')
  55. p4 = nn.functional.interpolate(p4, size=(self.numAngle, self.numRho), mode='bilinear')
  56. return torch.cat([p1, p2, p3, p4], dim=1)
  57. def forward(self, x):
  58. p1, p2, p3, p4 = self.backbone(x)
  59. p1 = self.dht_detector1(p1)
  60. p2 = self.dht_detector2(p2)
  61. p3 = self.dht_detector3(p3)
  62. p4 = self.dht_detector4(p4)
  63. cat = self.upsample_cat(p1, p2, p3, p4)
  64. logist = self.last_conv(cat)
  65. return logist