vgg_fpn.py

import math

import torch
from torch import nn
import torch.nn.functional as F

# VGG-16 configuration: numbers are conv output channels, 'M' marks a 2x2 max-pool.
vgg_config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


def vgg(cfg, i=3, batch_norm=False):
    """Build the VGG-16 convolutional trunk as a flat list of layers."""
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers


class VGG_Feature(nn.Module):
    """VGG-16 trunk that returns intermediate feature maps.

    With batch_norm=False, indices 8, 15, 22 and 29 are the ReLUs after
    conv2_2, conv3_3, conv4_3 and conv5_3 (128, 256, 512 and 512 channels).
    """

    def __init__(self, extract=(8, 15, 22, 29)):
        super(VGG_Feature, self).__init__()
        self.vgg = nn.ModuleList(vgg(cfg=vgg_config))
        self.extract = extract

    def forward(self, x):
        features = []
        for i in range(len(self.vgg)):
            x = self.vgg[i](x)
            if i in self.extract:
                features.append(x)
        return features


class VGG_FPN(nn.Module):
    """Feature Pyramid Network built on top of the VGG-16 backbone."""

    def __init__(self):
        super(VGG_FPN, self).__init__()
        self.vgg_feat = VGG_Feature()
        # 1x1 convolutions that project the backbone features to 256 channels.
        self.toplayer = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer1 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(128, 256, kernel_size=1, stride=1, padding=0)
        # 3x3 convolutions that smooth the merged top-down features.
        self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self._init_weight()
        self._load_pretrained_model()

    def forward(self, x):
        c2, c3, c4, c5 = self.vgg_feat(x)
        # Top-down pathway: upsample the coarser map and add the lateral connection.
        p5 = self.toplayer(c5)
        p4 = F.interpolate(p5, size=c4.size()[2:], mode='bilinear', align_corners=False) + self.latlayer1(c4)
        p3 = F.interpolate(p4, size=c3.size()[2:], mode='bilinear', align_corners=False) + self.latlayer2(c3)
        p2 = F.interpolate(p3, size=c2.size()[2:], mode='bilinear', align_corners=False) + self.latlayer3(c2)
        # Smooth the merged maps to reduce upsampling artifacts.
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)
        p2 = self.smooth3(p2)
        return p2, p3, p4, p5

    def _init_weight(self):
        # Kaiming-style initialization for conv layers; BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        # Load pretrained VGG-16 feature weights from a local checkpoint.
        self.vgg_feat.vgg.load_state_dict(torch.load('/home/hanqi/vgg16_feat.pth'))
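

# Minimal usage sketch (not part of the original module): it assumes the checkpoint at
# '/home/hanqi/vgg16_feat.pth' exists and that the input is a 3-channel image batch.
# For a 224x224 input, the pyramid levels come out at strides 2, 4, 8 and 16.
if __name__ == '__main__':
    model = VGG_FPN()
    model.eval()
    with torch.no_grad():
        p2, p3, p4, p5 = model(torch.randn(1, 3, 224, 224))
    # Expected shapes: p2 (1, 256, 112, 112), p3 (1, 256, 56, 56),
    #                  p4 (1, 256, 28, 28),  p5 (1, 256, 14, 14).
    for name, p in zip(('p2', 'p3', 'p4', 'p5'), (p2, p3, p4, p5)):
        print(name, tuple(p.shape))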