mobilenet.py

from torch import nn
import torch.utils.model_zoo as model_zoo


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: raw channel count (possibly fractional after width scaling)
    :param divisor: the value the result must be divisible by
    :param min_value: lower bound for the result (defaults to divisor)
    :return: the adjusted channel count
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
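

# A quick, informal sanity check of the rounding rule above (not part of the
# original file): with width_mult = 0.35 the first layer would request
# 32 * 0.35 = 11.2 channels; the nearest multiple of 8 is 8, but that drops
# more than 10% of 11.2, so the function rounds up to 16 instead.
# assert _make_divisible(32 * 0.35, 8) == 16
# assert _make_divisible(32 * 1.0, 8) == 32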


class ConvBNReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU6, with 'same'-style padding for odd kernel sizes."""

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
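

# Illustrative shape check for the block above (not in the original file; assumes
# the 24-channel stage of MobileNetV2 at 56x56 resolution):
# import torch
# block = InvertedResidual(inp=24, oup=24, stride=1, expand_ratio=6)
# y = block(torch.randn(1, 24, 56, 56))
# assert y.shape == (1, 24, 56, 56)  # stride 1 and inp == oup, so the residual path is used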


class MobileNetV2(nn.Module):
    def __init__(self, pretrained=True):
        """
        MobileNet V2 backbone with an FPN-style top-down head.
        The original torchvision arguments (num_classes, width_mult,
        inverted_residual_setting, round_nearest, block) are fixed inside
        this constructor rather than exposed as parameters.
        Args:
            pretrained (bool): If True, load ImageNet-pretrained MobileNetV2
                weights for all layers whose names match.
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        width_mult = 1.0
        round_nearest = 8
        # t: expansion factor, c: output channels, n: number of blocks, s: stride of the first block
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # only check the first element, assuming the user knows t, c, n, s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "and each element should have 4 values, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier (unused here; kept for reference)
        # self.classifier = nn.Sequential(
        #     nn.Dropout(0.2),
        #     nn.Linear(self.last_channel, num_classes),
        # )

        # FPN lateral and top-down layers (all projected to 32 channels)
        self.toplayer = nn.Conv2d(160, 32, kernel_size=1, stride=1, padding=0)
        self.latlayer1 = nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(24, 32, kernel_size=1, stride=1, padding=0)
        self.smooth1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # Indices of self.features whose outputs feed the FPN (strides 4, 8, 16, 32).
        self.fpn_selected = [2, 5, 9, 15]

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

        if pretrained:
            self._load_pretrained_model()

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        fpn_features = []
        for i, f in enumerate(self.features):
            x = f(x)
            if i in self.fpn_selected:
                fpn_features.append(x)
        c2, c3, c4, c5 = fpn_features

        # Top-down pathway with lateral connections
        p5 = self.toplayer(c5)
        p4 = nn.functional.interpolate(p5, size=c4.size()[2:], mode='bilinear', align_corners=True) + self.latlayer1(c4)
        p3 = nn.functional.interpolate(p4, size=c3.size()[2:], mode='bilinear', align_corners=True) + self.latlayer2(c3)
        p2 = nn.functional.interpolate(p3, size=c2.size()[2:], mode='bilinear', align_corners=True) + self.latlayer3(c2)
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)
        p2 = self.smooth3(p2)
        return p2, p3, p4, p5

        # Original classification head (unreachable; kept for reference):
        # x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        # x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        # x = self.classifier(x)
        # return x

    def forward(self, x):
        return self._forward_impl(x)

    def _load_pretrained_model(self):
        # Load ImageNet-pretrained MobileNetV2 weights and copy only the entries
        # whose names exist in this model; the FPN layers keep their random init.
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def MobileNet_FPN(output_stride=None, BatchNorm=nn.BatchNorm2d, pretrained=True):
    """Constructs a MobileNetV2-FPN model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    Note:
        output_stride and BatchNorm are accepted for interface compatibility
        but are currently unused.
    """
    model = MobileNetV2(pretrained=pretrained)
    return model
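

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file). With a 512x512 input
    # the selected backbone features sit at strides 4/8/16/32, so the FPN
    # outputs p2..p5 are expected at 128/64/32/16 resolution with 32 channels each.
    # pretrained=False is used here to avoid downloading weights.
    import torch

    net = MobileNet_FPN(pretrained=False)
    net.eval()
    with torch.no_grad():
        p2, p3, p4, p5 = net(torch.randn(1, 3, 512, 512))
    print(p2.shape, p3.shape, p4.shape, p5.shape)
    # expected: [1, 32, 128, 128], [1, 32, 64, 64], [1, 32, 32, 32], [1, 32, 16, 16]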