# resnet.py
import math

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
  4. class Bottleneck(nn.Module):
  5. expansion = 4
  6. def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
  7. super(Bottleneck, self).__init__()
  8. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  9. self.bn1 = BatchNorm(planes)
  10. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  11. dilation=dilation, padding=dilation, bias=False)
  12. self.bn2 = BatchNorm(planes)
  13. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  14. self.bn3 = BatchNorm(planes * 4)
  15. self.relu = nn.ReLU(inplace=True)
  16. self.downsample = downsample
  17. self.stride = stride
  18. self.dilation = dilation
  19. def forward(self, x):
  20. residual = x
  21. out = self.conv1(x)
  22. out = self.bn1(out)
  23. out = self.relu(out)
  24. out = self.conv2(out)
  25. out = self.bn2(out)
  26. out = self.relu(out)
  27. out = self.conv3(out)
  28. out = self.bn3(out)
  29. if self.downsample is not None:
  30. residual = self.downsample(x)
  31. out += residual
  32. out = self.relu(out)
  33. return out
  34. class ResNet(nn.Module):
  35. def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
  36. self.inplanes = 64
  37. super(ResNet, self).__init__()
  38. blocks = [1, 2, 4]
  39. if output_stride == 16:
  40. strides = [1, 2, 2, 1]
  41. dilations = [1, 1, 1, 2]
  42. elif output_stride == 8:
  43. strides = [1, 2, 1, 1]
  44. dilations = [1, 1, 2, 4]
  45. else:
  46. raise NotImplementedError
  47. # Modules
  48. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
  49. bias=False)
  50. self.bn1 = BatchNorm(64)
  51. self.relu = nn.ReLU(inplace=True)
  52. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  53. self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
  54. self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
  55. self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
  56. self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
  57. # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
  58. self._init_weight()
  59. if pretrained:
  60. self._load_pretrained_model()
  61. def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
  62. downsample = None
  63. if stride != 1 or self.inplanes != planes * block.expansion:
  64. downsample = nn.Sequential(
  65. nn.Conv2d(self.inplanes, planes * block.expansion,
  66. kernel_size=1, stride=stride, bias=False),
  67. BatchNorm(planes * block.expansion),
  68. )
  69. layers = []
  70. layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
  71. self.inplanes = planes * block.expansion
  72. for i in range(1, blocks):
  73. layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
  74. return nn.Sequential(*layers)
  75. def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
  76. downsample = None
  77. if stride != 1 or self.inplanes != planes * block.expansion:
  78. downsample = nn.Sequential(
  79. nn.Conv2d(self.inplanes, planes * block.expansion,
  80. kernel_size=1, stride=stride, bias=False),
  81. BatchNorm(planes * block.expansion),
  82. )
  83. layers = []
  84. layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
  85. downsample=downsample, BatchNorm=BatchNorm))
  86. self.inplanes = planes * block.expansion
  87. for i in range(1, len(blocks)):
  88. layers.append(block(self.inplanes, planes, stride=1,
  89. dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
  90. return nn.Sequential(*layers)
  91. def forward(self, input):
  92. x = self.conv1(input)
  93. x = self.bn1(x)
  94. x = self.relu(x)
  95. x = self.maxpool(x)
  96. x = self.layer1(x) # x4
  97. x = self.layer2(x) #x8
  98. x = self.layer3(x) #x16
  99. x = self.layer4(x) #x16
  100. return x
  101. def _init_weight(self):
  102. for m in self.modules():
  103. if isinstance(m, nn.Conv2d):
  104. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  105. m.weight.data.normal_(0, math.sqrt(2. / n))
  106. elif isinstance(m, nn.BatchNorm2d):
  107. m.weight.data.fill_(1)
  108. m.bias.data.zero_()
  109. def _load_pretrained_model(self):
  110. pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
  111. model_dict = {}
  112. state_dict = self.state_dict()
  113. for k, v in pretrain_dict.items():
  114. if k in state_dict:
  115. model_dict[k] = v
  116. state_dict.update(model_dict)
  117. self.load_state_dict(state_dict)
  118. def ResNet101(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True):
  119. """Constructs a ResNet-101 model.
  120. Args:
  121. pretrained (bool): If True, returns a model pre-trained on ImageNet
  122. """
  123. model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
  124. return model
  125. def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True):
  126. """Constructs a ResNet-101 model.
  127. Args:
  128. pretrained (bool): If True, returns a model pre-trained on ImageNet
  129. """
  130. model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained)
  131. return model
  132. if __name__ == "__main__":
  133. import torch
  134. model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
  135. input = torch.rand(1, 3, 480, 640)
  136. output = model(input)
  137. print(output.size())
  138. # print(low_level_feat.size())