# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import AdaptiveAvgPool2D, Linear
from paddle.regularizer import L2Decay
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Uniform
from numbers import Integral
import math

from ppdet.core.workspace import register
from ..shape_spec import ShapeSpec

__all__ = ['HRNet']

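# The backbone is assembled bottom-up from the layers defined below:
# ConvNormLayer (conv + BN/GN + optional ReLU), BottleneckBlock / BasicBlock
# (residual blocks, optionally with an SELayer), Branches and FuseLayers (the
# parallel multi-resolution branches and their cross-resolution fusion),
# HighResolutionModule and Stage, and finally the registered HRNet backbone.

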
class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride=1,
                 norm_type='bn',
                 norm_groups=32,
                 use_dcn=False,
                 norm_decay=0.,
                 freeze_norm=False,
                 act=None,
                 name=None):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn']

        self.act = act
        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=1,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0., std=0.01)),
            bias_attr=False)

        norm_lr = 0. if freeze_norm else 1.

        param_attr = ParamAttr(
            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
        bias_attr = ParamAttr(
            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
        global_stats = True if freeze_norm else None
        if norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(
                ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=global_stats)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)
        norm_params = self.norm.parameters()
        if freeze_norm:
            for param in norm_params:
                param.stop_gradient = True

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        if self.act == 'relu':
            out = F.relu(out)
        return out


class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 has_se=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []

        for i in range(4):
            bottleneck_block = self.add_sublayer(
                "block_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else 256,
                    num_filters=64,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, input):
        conv = input
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv


class TransitionLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        ConvNormLayer(
                            ch_in=in_channels[i],
                            ch_out=out_channels[i],
                            filter_size=3,
                            norm_decay=norm_decay,
                            freeze_norm=freeze_norm,
                            act='relu',
                            name=name + '_layer_' + str(i + 1)))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    ConvNormLayer(
                        ch_in=in_channels[-1],
                        ch_out=out_channels[i],
                        filter_size=3,
                        stride=2,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        act='relu',
                        name=name + '_layer_' + str(i + 1)))
            self.conv_bn_func_list.append(residual)

    def forward(self, input):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(input[idx])
            else:
                if idx < len(input):
                    outs.append(conv_bn_func(input[idx]))
                else:
                    outs.append(conv_bn_func(input[-1]))
        return outs


class Branches(nn.Layer):
    def __init__(self,
                 block_num,
                 in_channels,
                 out_channels,
                 has_se=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(Branches, self).__init__()

        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(block_num):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, inputs):
        outs = []
        for idx, input in enumerate(inputs):
            conv = input
            basic_block_list = self.basic_block_list[idx]
            for basic_block_func in basic_block_list:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = ConvNormLayer(
            ch_in=num_channels,
            ch_out=num_filters,
            filter_size=1,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters,
            filter_size=3,
            stride=stride,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act="relu",
            name=name + "_conv2")
        self.conv3 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters * 4,
            filter_size=1,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act=None,
            name=name + "_conv3")

        if self.downsample:
            self.conv_down = ConvNormLayer(
                ch_in=num_channels,
                ch_out=num_filters * 4,
                filter_size=1,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                act=None,
                name=name + "_downsample")

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(input)

        if self.has_se:
            conv3 = self.se(conv3)

        y = paddle.add(x=residual, y=conv3)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = ConvNormLayer(
            ch_in=num_channels,
            ch_out=num_filters,
            filter_size=3,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            stride=stride,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters,
            filter_size=3,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            stride=1,
            act=None,
            name=name + "_conv2")

        if self.downsample:
            self.conv_down = ConvNormLayer(
                ch_in=num_channels,
                ch_out=num_filters * 4,
                filter_size=1,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                act=None,
                name=name + "_downsample")

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(input)

        if self.has_se:
            conv2 = self.se(conv2)

        y = paddle.add(x=residual, y=conv2)
        y = F.relu(y)
        return y


class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()

        self.pool2d_gap = AdaptiveAvgPool2D(1)

        self._num_channels = num_channels

        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_filters,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

    def forward(self, input):
        pool = self.pool2d_gap(input)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = input * excitation
        return out


class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_filters,
                 has_se=False,
                 norm_decay=0.,
                 freeze_norm=True,
                 multi_scale_output=True,
                 name=None):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1)))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_filters=num_filters,
                        has_se=has_se,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name=name + '_' + str(i + 1)))

            self.stage_func_list.append(stage_func)

    def forward(self, input):
        out = input
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(HighResolutionModule, self).__init__()
        self.branches_func = Branches(
            block_num=4,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name)

    def forward(self, input):
        out = self.branches_func(input)
        out = self.fuse_func(out)
        return out


class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                residual_func = None
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        ConvNormLayer(
                            ch_in=in_channels[j],
                            ch_out=out_channels[i],
                            filter_size=1,
                            stride=1,
                            act=None,
                            norm_decay=norm_decay,
                            freeze_norm=freeze_norm,
                            name=name + '_layer_' + str(i + 1) + '_' +
                            str(j + 1)))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvNormLayer(
                                    ch_in=pre_num_filters,
                                    ch_out=out_channels[i],
                                    filter_size=3,
                                    stride=2,
                                    norm_decay=norm_decay,
                                    freeze_norm=freeze_norm,
                                    act=None,
                                    name=name + '_layer_' + str(i + 1) + '_' +
                                    str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvNormLayer(
                                    ch_in=pre_num_filters,
                                    ch_out=out_channels[j],
                                    filter_size=3,
                                    stride=2,
                                    norm_decay=norm_decay,
                                    freeze_norm=freeze_norm,
                                    act="relu",
                                    name=name + '_layer_' + str(i + 1) + '_' +
                                    str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, input):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = input[i]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](input[j])
                    residual_func_idx += 1

                    y = F.interpolate(y, scale_factor=2**(j - i))
                    residual = paddle.add(x=residual, y=y)
                elif j < i:
                    y = input[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1

                    residual = paddle.add(x=residual, y=y)

            residual = F.relu(residual)
            outs.append(residual)

        return outs

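
# HRNet backbone: a stride-4 stem (two 3x3, stride-2 convs), a bottleneck
# stage (Layer1), then alternating TransitionLayer/Stage pairs (tr1/st2,
# tr2/st3, tr3/st4) that grow the network from 2 to 4 parallel resolution
# branches.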
@register
class HRNet(nn.Layer):
    """
    HRNet backbone, see https://arxiv.org/abs/1908.07919

    An illustrative usage sketch is included at the end of this file.

    Args:
        width (int): width of HRNet, one of 18, 30, 32, 40, 44, 48, 60 or 64;
            selects the channel configuration of the parallel branches
        has_se (bool): whether to add an SE block in each residual block
        freeze_at (int): index of the stage-4 output whose gradient is stopped
        freeze_norm (bool): whether to freeze the normalization layers
        norm_decay (float): L2 weight decay applied to normalization layer
            weights
        return_idx (list): indices of the stage-4 outputs to return
        upsample (bool): whether to upsample the lower-resolution branches and
            concatenate them with the highest-resolution feature map
    """

    def __init__(self,
                 width=18,
                 has_se=False,
                 freeze_at=0,
                 freeze_norm=True,
                 norm_decay=0.,
                 return_idx=[0, 1, 2, 3],
                 upsample=False):
        super(HRNet, self).__init__()

        self.width = width
        self.has_se = has_se
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]

        assert len(return_idx) > 0, "need one or more return index"
        self.freeze_at = freeze_at
        self.return_idx = return_idx
        self.upsample = upsample

        self.channels = {
            18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
            30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
            32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
            40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
            48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
            60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
            64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
        }

        channels_2, channels_3, channels_4 = self.channels[width]
        num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
        self._out_channels = [sum(channels_4)] if self.upsample else channels_4
        self._out_strides = [4] if self.upsample else [4, 8, 16, 32]

        self.conv_layer1_1 = ConvNormLayer(
            ch_in=3,
            ch_out=64,
            filter_size=3,
            stride=2,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act='relu',
            name="layer1_1")

        self.conv_layer1_2 = ConvNormLayer(
            ch_in=64,
            ch_out=64,
            filter_size=3,
            stride=2,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act='relu',
            name="layer1_2")

        self.la1 = Layer1(
            num_channels=64,
            has_se=has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="layer2")

        self.tr1 = TransitionLayer(
            in_channels=[256],
            out_channels=channels_2,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr1")

        self.st2 = Stage(
            num_channels=channels_2,
            num_modules=num_modules_2,
            num_filters=channels_2,
            has_se=self.has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="st2")

        self.tr2 = TransitionLayer(
            in_channels=channels_2,
            out_channels=channels_3,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr2")

        self.st3 = Stage(
            num_channels=channels_3,
            num_modules=num_modules_3,
            num_filters=channels_3,
            has_se=self.has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="st3")

        self.tr3 = TransitionLayer(
            in_channels=channels_3,
            out_channels=channels_4,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr3")

        self.st4 = Stage(
            num_channels=channels_4,
            num_modules=num_modules_4,
            num_filters=channels_4,
            has_se=self.has_se,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            multi_scale_output=len(return_idx) > 1,
            name="st4")

    def forward(self, inputs):
        x = inputs['image']
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)

        la1 = self.la1(conv2)
        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)
        tr2 = self.tr2(st2)

        st3 = self.st3(tr2)
        tr3 = self.tr3(st3)

        st4 = self.st4(tr3)

        if self.upsample:
            # Upsample all branches to the highest resolution and concatenate.
            x0_h, x0_w = st4[0].shape[2:4]
            x1 = F.upsample(st4[1], size=(x0_h, x0_w), mode='bilinear')
            x2 = F.upsample(st4[2], size=(x0_h, x0_w), mode='bilinear')
            x3 = F.upsample(st4[3], size=(x0_h, x0_w), mode='bilinear')
            x = paddle.concat([st4[0], x1, x2, x3], 1)
            return x

        res = []
        for i, layer in enumerate(st4):
            if i == self.freeze_at:
                layer.stop_gradient = True
            if i in self.return_idx:
                res.append(layer)

        return res

    @property
    def out_shape(self):
        if self.upsample:
            self.return_idx = [0]
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]
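

# Illustrative usage sketch (an assumption, not part of the original file):
# it only shows the forward() contract -- a dict with an 'image' tensor goes
# in, a list of feature maps comes out. In PaddleDetection the backbone is
# normally built from a YAML config through the @register workspace rather
# than by hand, and the relative import at the top of this file means the
# sketch only runs as a module (e.g. `python -m ...hrnet`, assuming the file
# lives in its usual ppdet/modeling/backbones location), not as a script.
if __name__ == '__main__':
    # Build an HRNet-W18 backbone and run a dummy batch through it.
    backbone = HRNet(width=18, return_idx=[0, 1, 2, 3])
    dummy = {'image': paddle.randn([1, 3, 512, 512])}
    feats = backbone(dummy)
    # Four feature maps at strides 4, 8, 16 and 32, matching out_shape.
    print([f.shape for f in feats])
    print(backbone.out_shape)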