# hrnet.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. from paddle import fluid
  19. from paddle.fluid.param_attr import ParamAttr
  20. from paddle.fluid.framework import Variable
  21. from paddle.fluid.regularizer import L2Decay
  22. from ppdet.core.workspace import register, serializable
  23. from numbers import Integral
  24. from paddle.fluid.initializer import MSRA
  25. import math
  26. __all__ = ['HRNet']
  27. @register
  28. @serializable
  29. class HRNet(object):
  30. """
  31. HRNet, see https://arxiv.org/abs/1908.07919
  32. Args:
  33. width (int): network width, should be 18, 30, 32, 40, 44, 48, 60 or 64
  34. has_se (bool): whether contain squeeze_excitation(SE) block or not
  35. freeze_at (int): freeze the backbone at which stage
  36. norm_type (str): normalization type, 'bn'/'sync_bn'
  37. freeze_norm (bool): freeze normalization layers
  38. norm_decay (float): weight decay for normalization layer weights
  39. feature_maps (list): index of stages whose feature maps are returned
  40. """
  41. def __init__(self,
  42. width=40,
  43. has_se=False,
  44. freeze_at=2,
  45. norm_type='bn',
  46. freeze_norm=True,
  47. norm_decay=0.,
  48. feature_maps=[2, 3, 4, 5]):
  49. super(HRNet, self).__init__()
  50. if isinstance(feature_maps, Integral):
  51. feature_maps = [feature_maps]
  52. assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
  53. assert len(feature_maps) > 0, "need one or more feature maps"
  54. assert norm_type in ['bn', 'sync_bn']
  55. self.width = width
  56. self.has_se = has_se
  57. self.channels = {
  58. 18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
  59. 30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
  60. 32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
  61. 40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
  62. 44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
  63. 48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
  64. 60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
  65. 64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]],
  66. }
  67. self.freeze_at = freeze_at
  68. self.norm_type = norm_type
  69. self.norm_decay = norm_decay
  70. self.freeze_norm = freeze_norm
  71. self._model_type = 'HRNet'
  72. self.feature_maps = feature_maps
  73. self.end_points = []
  74. return
  75. def net(self, input, class_dim=1000):
  76. width = self.width
  77. channels_2, channels_3, channels_4 = self.channels[width]
  78. num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
  79. x = self.conv_bn_layer(
  80. input=input,
  81. filter_size=3,
  82. num_filters=64,
  83. stride=2,
  84. if_act=True,
  85. name='layer1_1')
  86. x = self.conv_bn_layer(
  87. input=x,
  88. filter_size=3,
  89. num_filters=64,
  90. stride=2,
  91. if_act=True,
  92. name='layer1_2')
  93. la1 = self.layer1(x, name='layer2')
  94. tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
  95. st2 = self.stage(tr1, num_modules_2, channels_2, name='st2')
  96. tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
  97. st3 = self.stage(tr2, num_modules_3, channels_3, name='st3')
  98. tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
  99. st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
  100. self.end_points = st4
  101. return st4[-1]
  102. def layer1(self, input, name=None):
  103. conv = input
  104. for i in range(4):
  105. conv = self.bottleneck_block(
  106. conv,
  107. num_filters=64,
  108. downsample=True if i == 0 else False,
  109. name=name + '_' + str(i + 1))
  110. return conv
  111. def transition_layer(self, x, in_channels, out_channels, name=None):
  112. num_in = len(in_channels)
  113. num_out = len(out_channels)
  114. out = []
  115. for i in range(num_out):
  116. if i < num_in:
  117. if in_channels[i] != out_channels[i]:
  118. residual = self.conv_bn_layer(
  119. x[i],
  120. filter_size=3,
  121. num_filters=out_channels[i],
  122. name=name + '_layer_' + str(i + 1))
  123. out.append(residual)
  124. else:
  125. out.append(x[i])
  126. else:
  127. residual = self.conv_bn_layer(
  128. x[-1],
  129. filter_size=3,
  130. num_filters=out_channels[i],
  131. stride=2,
  132. name=name + '_layer_' + str(i + 1))
  133. out.append(residual)
  134. return out
  135. def branches(self, x, block_num, channels, name=None):
  136. out = []
  137. for i in range(len(channels)):
  138. residual = x[i]
  139. for j in range(block_num):
  140. residual = self.basic_block(
  141. residual,
  142. channels[i],
  143. name=name + '_branch_layer_' + str(i + 1) + '_' +
  144. str(j + 1))
  145. out.append(residual)
  146. return out
  147. def fuse_layers(self, x, channels, multi_scale_output=True, name=None):
  148. out = []
  149. for i in range(len(channels) if multi_scale_output else 1):
  150. residual = x[i]
  151. for j in range(len(channels)):
  152. if j > i:
  153. y = self.conv_bn_layer(
  154. x[j],
  155. filter_size=1,
  156. num_filters=channels[i],
  157. if_act=False,
  158. name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
  159. y = fluid.layers.resize_nearest(input=y, scale=2**(j - i))
  160. residual = fluid.layers.elementwise_add(
  161. x=residual, y=y, act=None)
  162. elif j < i:
  163. y = x[j]
  164. for k in range(i - j):
  165. if k == i - j - 1:
  166. y = self.conv_bn_layer(
  167. y,
  168. filter_size=3,
  169. num_filters=channels[i],
  170. stride=2,
  171. if_act=False,
  172. name=name + '_layer_' + str(i + 1) + '_' +
  173. str(j + 1) + '_' + str(k + 1))
  174. else:
  175. y = self.conv_bn_layer(
  176. y,
  177. filter_size=3,
  178. num_filters=channels[j],
  179. stride=2,
  180. name=name + '_layer_' + str(i + 1) + '_' +
  181. str(j + 1) + '_' + str(k + 1))
  182. residual = fluid.layers.elementwise_add(
  183. x=residual, y=y, act=None)
  184. residual = fluid.layers.relu(residual)
  185. out.append(residual)
  186. return out
  187. def high_resolution_module(self,
  188. x,
  189. channels,
  190. multi_scale_output=True,
  191. name=None):
  192. residual = self.branches(x, 4, channels, name=name)
  193. out = self.fuse_layers(
  194. residual,
  195. channels,
  196. multi_scale_output=multi_scale_output,
  197. name=name)
  198. return out
  199. def stage(self,
  200. x,
  201. num_modules,
  202. channels,
  203. multi_scale_output=True,
  204. name=None):
  205. out = x
  206. for i in range(num_modules):
  207. if i == num_modules - 1 and multi_scale_output == False:
  208. out = self.high_resolution_module(
  209. out,
  210. channels,
  211. multi_scale_output=False,
  212. name=name + '_' + str(i + 1))
  213. else:
  214. out = self.high_resolution_module(
  215. out, channels, name=name + '_' + str(i + 1))
  216. return out
  217. def last_cls_out(self, x, name=None):
  218. out = []
  219. num_filters_list = [128, 256, 512, 1024]
  220. for i in range(len(x)):
  221. out.append(
  222. self.conv_bn_layer(
  223. input=x[i],
  224. filter_size=1,
  225. num_filters=num_filters_list[i],
  226. name=name + 'conv_' + str(i + 1)))
  227. return out
  228. def basic_block(self,
  229. input,
  230. num_filters,
  231. stride=1,
  232. downsample=False,
  233. name=None):
  234. residual = input
  235. conv = self.conv_bn_layer(
  236. input=input,
  237. filter_size=3,
  238. num_filters=num_filters,
  239. stride=stride,
  240. name=name + '_conv1')
  241. conv = self.conv_bn_layer(
  242. input=conv,
  243. filter_size=3,
  244. num_filters=num_filters,
  245. if_act=False,
  246. name=name + '_conv2')
  247. if downsample:
  248. residual = self.conv_bn_layer(
  249. input=input,
  250. filter_size=1,
  251. num_filters=num_filters,
  252. if_act=False,
  253. name=name + '_downsample')
  254. if self.has_se:
  255. conv = self.squeeze_excitation(
  256. input=conv,
  257. num_channels=num_filters,
  258. reduction_ratio=16,
  259. name='fc' + name)
  260. return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
  261. def bottleneck_block(self,
  262. input,
  263. num_filters,
  264. stride=1,
  265. downsample=False,
  266. name=None):
  267. residual = input
  268. conv = self.conv_bn_layer(
  269. input=input,
  270. filter_size=1,
  271. num_filters=num_filters,
  272. name=name + '_conv1')
  273. conv = self.conv_bn_layer(
  274. input=conv,
  275. filter_size=3,
  276. num_filters=num_filters,
  277. stride=stride,
  278. name=name + '_conv2')
  279. conv = self.conv_bn_layer(
  280. input=conv,
  281. filter_size=1,
  282. num_filters=num_filters * 4,
  283. if_act=False,
  284. name=name + '_conv3')
  285. if downsample:
  286. residual = self.conv_bn_layer(
  287. input=input,
  288. filter_size=1,
  289. num_filters=num_filters * 4,
  290. if_act=False,
  291. name=name + '_downsample')
  292. if self.has_se:
  293. conv = self.squeeze_excitation(
  294. input=conv,
  295. num_channels=num_filters * 4,
  296. reduction_ratio=16,
  297. name='fc' + name)
  298. return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
  299. def squeeze_excitation(self,
  300. input,
  301. num_channels,
  302. reduction_ratio,
  303. name=None):
  304. pool = fluid.layers.pool2d(
  305. input=input, pool_size=0, pool_type='avg', global_pooling=True)
  306. stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
  307. squeeze = fluid.layers.fc(
  308. input=pool,
  309. size=num_channels / reduction_ratio,
  310. act='relu',
  311. param_attr=fluid.param_attr.ParamAttr(
  312. initializer=fluid.initializer.Uniform(-stdv, stdv),
  313. name=name + '_sqz_weights'),
  314. bias_attr=ParamAttr(name=name + '_sqz_offset'))
  315. stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
  316. excitation = fluid.layers.fc(
  317. input=squeeze,
  318. size=num_channels,
  319. act='sigmoid',
  320. param_attr=fluid.param_attr.ParamAttr(
  321. initializer=fluid.initializer.Uniform(-stdv, stdv),
  322. name=name + '_exc_weights'),
  323. bias_attr=ParamAttr(name=name + '_exc_offset'))
  324. scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
  325. return scale
  326. def conv_bn_layer(self,
  327. input,
  328. filter_size,
  329. num_filters,
  330. stride=1,
  331. padding=1,
  332. num_groups=1,
  333. if_act=True,
  334. name=None):
  335. conv = fluid.layers.conv2d(
  336. input=input,
  337. num_filters=num_filters,
  338. filter_size=filter_size,
  339. stride=stride,
  340. padding=(filter_size - 1) // 2,
  341. groups=num_groups,
  342. act=None,
  343. param_attr=ParamAttr(
  344. initializer=MSRA(), name=name + '_weights'),
  345. bias_attr=False)
  346. bn_name = name + '_bn'
  347. bn = self._bn(input=conv, bn_name=bn_name)
  348. if if_act:
  349. bn = fluid.layers.relu(bn)
  350. return bn
  351. def _bn(self, input, act=None, bn_name=None):
  352. norm_lr = 0. if self.freeze_norm else 1.
  353. norm_decay = self.norm_decay
  354. pattr = ParamAttr(
  355. name=bn_name + '_scale',
  356. learning_rate=norm_lr,
  357. regularizer=L2Decay(norm_decay))
  358. battr = ParamAttr(
  359. name=bn_name + '_offset',
  360. learning_rate=norm_lr,
  361. regularizer=L2Decay(norm_decay))
  362. global_stats = True if self.freeze_norm else False
  363. out = fluid.layers.batch_norm(
  364. input=input,
  365. act=act,
  366. name=bn_name + '.output.1',
  367. param_attr=pattr,
  368. bias_attr=battr,
  369. moving_mean_name=bn_name + '_mean',
  370. moving_variance_name=bn_name + '_variance',
  371. use_global_stats=global_stats)
  372. scale = fluid.framework._get_var(pattr.name)
  373. bias = fluid.framework._get_var(battr.name)
  374. if self.freeze_norm:
  375. scale.stop_gradient = True
  376. bias.stop_gradient = True
  377. return out
  378. def __call__(self, input):
  379. assert isinstance(input, Variable)
  380. assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
  381. "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
  382. res_endpoints = []
  383. res = input
  384. feature_maps = self.feature_maps
  385. self.net(input)
  386. for i in feature_maps:
  387. res = self.end_points[i - 2]
  388. if i in self.feature_maps:
  389. res_endpoints.append(res)
  390. if self.freeze_at >= i:
  391. res.stop_gradient = True
  392. return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
  393. for idx, feat in enumerate(res_endpoints)])